Skip to content

Commit b2fcf64

Browse files
authored
[FlexDecoding] Support M < 8 tt.dot with DPAS to optimize the flex decoding performance. (#4727)
Use M size as the repeat count if M < 8. It can help to reduce the number of duplicated redundant value of DotOp and DPAS layout. Signed-off-by: Lu,Chengjun <[email protected]>
1 parent b7ab0b8 commit b2fcf64

File tree

4 files changed

+43
-6
lines changed

4 files changed

+43
-6
lines changed

python/test/unit/language/test_core.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3927,7 +3927,7 @@ def get_test_dot_vdot2_cases():
39273927

39283928

39293929
def get_test_small_dots_cases():
3930-
if not is_cuda():
3930+
if not (is_cuda() or is_xpu()):
39313931
return []
39323932
return [(2, 4, 32, 1, False, False, 'None', 'ieee', 'float16', 'float32', 1, None),
39333933
(1, 2, 32, 1, False, False, 'None', 'ieee', 'float8e5', 'float32', 1, None)]
@@ -6211,6 +6211,8 @@ def kernel(Out):
62116211
dim=1, parent=DotOperandLayout(parent=MmaLayout([2, 0], [4, 1, 1], [1, 1, 1], [1, 1, 1], [2, 1, 0], [1, 16, 8]),
62126212
op_idx=1, k_width=2)),
62136213
DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
6214+
warps_per_cta=[4, 1], rep_cluster=[1, 1]),
6215+
DpasLayout(repeatCount=2, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
62146216
warps_per_cta=[4, 1], rep_cluster=[1, 1])
62156217
]
62166218

test/TritonIntelGPU/accelerate-matmul-pvc.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,3 +333,38 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
333333
tt.return
334334
}
335335
}
336+
337+
// -----
338+
339+
// CHECK: #[[$DPAS0:.+]] = #ttig.dpas<{repeatCount = 1, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1], A = [1, 16], B = [16, 16], C = [1, 16]}>
340+
// CHECK: #[[$DPAS1:.+]] = #ttig.dpas<{repeatCount = 2, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1], A = [2, 16], B = [16, 16], C = [2, 16]}>
341+
// CHECK: #[[$DPAS2:.+]] = #ttig.dpas<{repeatCount = 4, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1], A = [4, 16], B = [16, 16], C = [4, 16]}>
342+
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
343+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.min_sg_size" = 16 : i32, "ttig.support_dpas"} {
344+
tt.func @M_smaller_than_8(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
345+
// CHECK-LABEL: M_smaller_than_8
346+
%b = arith.constant dense<0.000000e+00> : tensor<128x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
347+
348+
// CHECK: tt.dot {{.*}} -> tensor<1x16xf32, #[[$DPAS0]]>
349+
%a0 = arith.constant dense<0.000000e+00> : tensor<1x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
350+
%zero0 = arith.constant dense<0.000000e+00> : tensor<1x16xf32, #blocked>
351+
%result0 = tt.dot %a0, %b, %zero0 : tensor<1x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<1x16xf32, #blocked>
352+
%result_ptr0 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1x16x!tt.ptr<f32>, #blocked>
353+
tt.store %result_ptr0, %result0 : tensor<1x16x!tt.ptr<f32>, #blocked>
354+
355+
// CHECK: tt.dot {{.*}} -> tensor<2x16xf32, #[[$DPAS1]]>
356+
%a1 = arith.constant dense<0.000000e+00> : tensor<2x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
357+
%zero1 = arith.constant dense<0.000000e+00> : tensor<2x16xf32, #blocked>
358+
%result1 = tt.dot %a1, %b, %zero1 : tensor<2x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<2x16xf32, #blocked>
359+
%result_ptr1 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<2x16x!tt.ptr<f32>, #blocked>
360+
tt.store %result_ptr1, %result1 : tensor<2x16x!tt.ptr<f32>, #blocked>
361+
362+
// CHECK: tt.dot {{.*}} -> tensor<4x16xf32, #[[$DPAS2]]>
363+
%a2 = arith.constant dense<0.000000e+00> : tensor<4x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
364+
%zero2 = arith.constant dense<0.000000e+00> : tensor<4x16xf32, #blocked>
365+
%result2 = tt.dot %a2, %b, %zero2 : tensor<4x128xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x16xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<4x16xf32, #blocked>
366+
%result_ptr2 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<4x16x!tt.ptr<f32>, #blocked>
367+
tt.store %result_ptr2, %result2 : tensor<4x16x!tt.ptr<f32>, #blocked>
368+
tt.return
369+
}
370+
}

third_party/intel/backend/compiler.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,7 @@ def min_dot_size(device_props: dict):
6767
# M: repeatCount. 1,2,4,8
6868
# N: executionSize. 16 for PVC, 8 for ATS
6969
# K: systolicDepth x opsPerChan. systolicDepth must be 8
70-
71-
# default 8 because 1,2,4 is not supported by our backend now.
72-
repeat_count = 8
70+
repeat_count = 1
7371
sdepth = 8
7472
exec_size = min(device_props["sub_group_sizes"])
7573

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,11 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
123123
size_t rank = retShape.size();
124124
SmallVector<unsigned> repCluster(rank, 1);
125125

126+
unsigned repeatCount =
127+
std::min(dpasCap.repeatCount, (unsigned)retShape[rank - 2] /*M*/);
126128
unsigned threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod);
127129
auto dpasEnc = ttgi::DpasEncodingAttr::get(
128-
oldRetType.getContext(), dpasCap.repeatCount, dpasCap.systolicDepth,
130+
oldRetType.getContext(), repeatCount, dpasCap.systolicDepth,
129131
dpasCap.executionSize, opsPerChan, warpsPerTile, repCluster,
130132
threadsPerWarp);
131133

@@ -157,7 +159,7 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
157159
repCluster[rank - 1] = repClusterDimN;
158160

159161
dpasEnc = ttgi::DpasEncodingAttr::get(
160-
oldRetType.getContext(), dpasCap.repeatCount, dpasCap.systolicDepth,
162+
oldRetType.getContext(), repeatCount, dpasCap.systolicDepth,
161163
dpasCap.executionSize, opsPerChan, warpsPerTile, repCluster,
162164
threadsPerWarp);
163165
}

0 commit comments

Comments
 (0)