Skip to content

Commit a64d2c2

Browse files
authored
Workaround issue 1802 (#1800)
Workaround for issue 1802
1 parent 1fb495e commit a64d2c2

File tree

5 files changed

+49
-16
lines changed

5 files changed

+49
-16
lines changed

mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1760,6 +1760,13 @@ struct GridwiseAttentionAccelRewritePattern
17601760
if (doBypassLDSForQ) {
17611761
ldsLayoutCfgNG0.doSwapThreadIterSubDims = false;
17621762
}
1763+
#ifndef ROCK_DEBUG_ATTENTION_REMOVE_SOFTMAX
1764+
// TODO: Workaround for issue
1765+
// https://github.com/ROCm/rocMLIR-internal/issues/1802. If sumRowBuffer and
1766+
// expMaxDiffRowBuffer are filled with doSwapThreadIterSubDims=true, it does
1767+
// not match with the second GEMM N dimension. Find a good solution to this.
1768+
ldsLayoutCfgNG0.doSwapThreadIterSubDims = false;
1769+
#endif
17631770
FailureOr<VectorDimInfo> maybeVectorDimInfoK =
17641771
getVectorDim(rewriter, loc, inK, elemTypeK, blockSize, gemm0KPerBlock,
17651772
gemm0MPerBlock, gemm0kpack);
@@ -1834,7 +1841,7 @@ struct GridwiseAttentionAccelRewritePattern
18341841
Value accRegBufferGemm0 =
18351842
createBufferForAccelGemmOut(loc, accelParamsGemm0, rewriter);
18361843
// Currently, there is a working assumption that this kernel is meant
1837-
// support fp32/fp16 This should be guranteed by op verifiers.
1844+
// to support fp32/fp16/bf16. This should be guaranteed by op verifiers.
18381845
Type gemmOutElemType = elemTypeQxK;
18391846
Type softmaxInElemType = elemTypeQxK;
18401847
if (elemTypeQ == rewriter.getI8Type()) {
@@ -1985,12 +1992,13 @@ struct GridwiseAttentionAccelRewritePattern
19851992
if (failed(statusLoadQ)) {
19861993
return failure();
19871994
}
1995+
rewriter.create<LDSBarrierOp>(loc);
19881996

19891997
TypedValue<MemRefType> ldsTileBufferQ = viewBufferAs(
19901998
rewriter, ldsByteBufferQ, vectorTypeOrSelf(elemTypeQ, gemm0kpack));
19911999
loadGemmOperandsFromLDSToRegs(
19922000
rewriter, loc, ldsTileBufferQ, preAccelRegBuffersQ, "n", blockSize,
1993-
gemm0InMPerThread, *accelEmitterPtrGemm0.get(),
2001+
gemm0InNPerThread, *accelEmitterPtrGemm0.get(),
19942002
ldsLayoutCfgNG0.doRotateWithK);
19952003
rewriter.create<GpuDeallocOp>(loc, ldsByteBufferQ);
19962004
}

mlir/test/e2e/PrAttentionBF16.toml

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,25 @@ prefix = "rocmlir-gen"
33
suffix = "--operation attention -t bf16 --arch %arch -pv %random_data %rocmlir_gen_flags -relDiff_threshold 0.3 -absDiff_threshold 0.3 -RMS_threshold 0.15 | rocmlir-driver -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix="
44

55
[[axis]]
6-
name = "operation"
7-
values = ["attention"]
8-
prefix = "--operation "
6+
name = "transQ"
7+
values = ["true", "false"]
8+
prefix = "--transQ="
99

1010
[[axis]]
1111
name = "transK"
1212
values = ["true", "false"]
1313
prefix = "--transK="
1414

15+
[[axis]]
16+
name = "transV"
17+
values = ["true", "false"]
18+
prefix = "--transV="
19+
20+
[[axis]]
21+
name = "transO"
22+
values = ["true", "false"]
23+
prefix = "--transO="
24+
1525
## attention variant
1626
[[suite]]
1727
name = "pr_attention_bf16"

mlir/test/e2e/PrAttentionF16.toml

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,25 @@ prefix = "rocmlir-gen"
33
suffix = "--operation attention -t f16 --arch %arch -pv %random_data %rocmlir_gen_flags -relDiff_threshold 0.02 -absDiff_threshold 0.1 -RMS_threshold 0.015 | rocmlir-driver -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix="
44

55
[[axis]]
6-
name = "operation"
7-
values = ["attention"]
8-
prefix = "--operation "
6+
name = "transQ"
7+
values = ["true", "false"]
8+
prefix = "--transQ="
99

1010
[[axis]]
1111
name = "transK"
1212
values = ["true", "false"]
1313
prefix = "--transK="
1414

15+
[[axis]]
16+
name = "transV"
17+
values = ["true", "false"]
18+
prefix = "--transV="
19+
20+
[[axis]]
21+
name = "transO"
22+
values = ["true", "false"]
23+
prefix = "--transO="
24+
1525
## attention variant
1626
[[suite]]
1727
name = "pr_attention_f16"

mlir/test/e2e/PrAttentionF32.toml

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,25 @@ prefix = "rocmlir-gen"
33
suffix = "--operation attention -t f32 --arch %arch -pv %random_data %rocmlir_gen_flags -relDiff_threshold 0.00005 | rocmlir-driver -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix="
44

55
[[axis]]
6-
name = "operation"
7-
values = ["attention"]
8-
prefix = "--operation "
6+
name = "transQ"
7+
values = ["true", "false"]
8+
prefix = "--transQ="
99

1010
[[axis]]
1111
name = "transK"
1212
values = ["true", "false"]
1313
prefix = "--transK="
1414

15+
[[axis]]
16+
name = "transV"
17+
values = ["true", "false"]
18+
prefix = "--transV="
19+
20+
[[axis]]
21+
name = "transO"
22+
values = ["true", "false"]
23+
prefix = "--transO="
24+
1525
## attention variant
1626
[[suite]]
1727
name = "pr_attention_f32"

mlir/test/e2e/PrAttentionI8.toml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,6 @@ directory = "PrAttentionI8"
22
prefix = "rocmlir-gen"
33
suffix = "--operation attention -t i8 --arch %arch -pv %random_data %rocmlir_gen_flags -RMS_threshold 0.01 | rocmlir-driver -c | mlir-runner -O2 --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext --entry-point-result=void | FileCheck %s --check-prefix="
44

5-
[[axis]]
6-
name = "operation"
7-
values = ["attention"]
8-
prefix = "--operation "
9-
105
[[axis]]
116
name = "transK"
127
values = ["true", "false"]

0 commit comments

Comments
 (0)