Skip to content

Commit 7dbaad3

Browse files
authored
Fix bug in DPAS analysis. (#4252)
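Fixes a bug in the DPAS analysis that caused some dot ops in a function to be missed: the iterator into funcToDotMap was looked up once, before walking the function, so for a function with no existing entry every dot op took the `funcToDotMap[funcOp] = {op}` branch and overwrote the previously recorded op. The map entry is now created before the walk, so every dot op is appended to it. Signed-off-by: Lu,Chengjun <[email protected]>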
1 parent dcab8e3 commit 7dbaad3

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

test/TritonIntelGPU/accelerate-matmul-pvc.mlir

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -312,3 +312,24 @@ module attributes {ttg.target = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warps"
312312
tt.return
313313
}
314314
}
315+
316+
// -----
317+
318+
// CHECK-NOT: triton_intel_gpu.dpas
319+
#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
320+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32, "triton_intel_gpu.min_sg_size" = 16 : i32, "triton_intel_gpu.support_dpas"} {
321+
// CHECK-LABEL: check_dpas_cap
322+
tt.func @check_dpas_cap(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
323+
%zero_f32 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #blocked>
324+
%a = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
325+
%b = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
326+
327+
%result = tt.dot %a, %b, %zero_f32, inputPrecision = tf32 : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x16xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x16xf32, #blocked>
328+
%result_ptr = tt.splat %arg0 : !tt.ptr<f32> -> tensor<128x16x!tt.ptr<f32>, #blocked>
329+
tt.store %result_ptr, %result : tensor<128x16x!tt.ptr<f32>, #blocked>
330+
331+
%result2 = tt.dot %a, %b, %zero_f32 : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x16xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x16xf32, #blocked>
332+
tt.store %result_ptr, %result2 : tensor<128x16x!tt.ptr<f32>, #blocked>
333+
tt.return
334+
}
335+
}

third_party/intel/lib/Analysis/DPAS.cpp

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -17,16 +17,15 @@ DPASAnalysis::DPASAnalysis(Operation *root) {
1717

1818
// Populate the maps.
1919
mod.walk([&](FunctionOpInterface funcOp) {
20+
if (funcToDotMap.find(funcOp) == funcToDotMap.end())
21+
funcToDotMap[funcOp] = {};
2022
auto it = funcToDotMap.find(funcOp);
2123

2224
funcOp.walk([&](Operation *op) {
2325
if (!isa<DotOp, DotScaledOp>(op))
2426
return;
2527

26-
if (it != funcToDotMap.end())
27-
it->second.push_back(op);
28-
else
29-
funcToDotMap[funcOp] = {op};
28+
it->second.push_back(op);
3029

3130
DPASEngineType dpasEngineType = supportDPAS
3231
? DPASAnalysis::getDPASType(op)

0 commit comments

Comments
 (0)