Commit b5c46f0

[AccelerateMatmul] Fix getWarpsPerTile with rank > 2 (#5247)
Fixes pytorch/helion#772, #5246

Signed-off-by: Whitney Tsang <[email protected]>
1 parent 63ad1ad · commit b5c46f0

File tree

2 files changed: 22 additions & 1 deletion


test/TritonIntelGPU/accelerate-matmul-pvc.mlir

Lines changed: 19 additions & 0 deletions
@@ -417,3 +417,22 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
     tt.return
   }
 }
+
+// -----
+
+// CHECK: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1, 1], repCluster = [1, 4, 2], A = [1, 32, 16], B = [1, 16, 32], C = [1, 32, 32]}>
+#blocked = #ttg.blocked<{sizePerThread = [1, 4, 4], threadsPerWarp = [1, 1, 16], warpsPerCTA = [1, 4, 1], order = [2, 1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.min_sg_size" = 16 : i32, ttig.support_dpas} {
+  tt.func public @_helion_repro_baddbmm_kernel(%A: tensor<1x64x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>, %B: tensor<1x64x64xbf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>, %C: tensor<1x64x64x!tt.ptr<bf16>, #blocked>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<1x64x64xf32, #blocked>
+    // CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<1x64x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<1x64x64xbf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<1x64x64xf32, #[[$DPAS]]>
+    %31 = tt.dot %A, %B, %cst, inputPrecision = tf32 : tensor<1x64x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<1x64x64xbf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<1x64x64xf32, #blocked>
+    %39 = arith.truncf %31 : tensor<1x64x64xf32, #blocked> to tensor<1x64x64xbf16, #blocked>
+    %40 = ttg.convert_layout %39 : tensor<1x64x64xbf16, #blocked> -> tensor<1x64x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
+    // CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<1x64x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>> * tensor<1x64x64xbf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>> -> tensor<1x64x64xf32, #[[$DPAS]]>
+    %42 = tt.dot %40, %B, %cst, inputPrecision = tf32 : tensor<1x64x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<1x64x64xbf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<1x64x64xf32, #blocked>
+    %43 = arith.truncf %42 : tensor<1x64x64xf32, #blocked> to tensor<1x64x64xbf16, #blocked>
+    tt.store %C, %43 : tensor<1x64x64x!tt.ptr<bf16>, #blocked>
+    tt.return
+  }
+}
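This new case is distilled from the Helion baddbmm reproducer: both tt.dot ops take rank-3 (1x64x64) operands, and the CHECK lines require the rewritten dots to carry a DPAS encoding whose warpsPerCTA = [4, 1, 1] has one entry per tensor dimension. The RUN line lives at the top of the file, outside this hunk; like the other tests in this suite it is a lit/FileCheck test, driven by something along these lines (the exact invocation below is an assumption for illustration, not copied from the file):

// Hypothetical RUN line; the real one is not part of this diff.
// RUN: triton-opt %s -split-input-file --tritonintelgpu-accelerate-matmul | FileCheck %s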

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 3 additions & 1 deletion
@@ -63,7 +63,9 @@ getWarpsPerTile(tt::DotOp dotOp, ttgi::DpasEncodingAttr::DPASCapability dpasCap,
       setAttrOnBOperand(dotOp, attrName, UnitAttr::get(ctx));
       setAttrOnBOperand(cast<tt::DotOp>(op), attrName, UnitAttr::get(ctx));
     }
-    return {numWarps, 1};
+    SmallVector<unsigned> ret(shape.size(), 1);
+    ret[0] = numWarps;
+    return ret;
   }
 }
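On the code path shown in this hunk, getWarpsPerTile returned the fixed rank-2 vector {numWarps, 1}, so for a 3-D (batched) dot the warps-per-tile vector no longer matched the rank of the result tensor, which is what broke the baddbmm case above. The fix sizes the result from shape.size() instead, one warp-count entry per dimension. A minimal standalone sketch of the corrected shape logic, using std::vector in place of LLVM's SmallVector (the function name and parameter types here are illustrative, not the real signature):

#include <cassert>
#include <cstdint>
#include <vector>

// Sketch: all warps are placed on the outermost dimension and every
// remaining dimension gets a factor of 1, so the result always has
// one entry per dimension of the dot result shape.
std::vector<unsigned> warpsPerTile(const std::vector<int64_t> &shape,
                                   unsigned numWarps) {
  std::vector<unsigned> ret(shape.size(), 1);
  ret[0] = numWarps;
  return ret;
}

int main() {
  // Rank 2: identical to the old hard-coded {numWarps, 1}.
  assert((warpsPerTile({64, 64}, 4) == std::vector<unsigned>{4, 1}));
  // Rank 3 (the baddbmm test above): the old code produced a rank-2
  // vector for a rank-3 tensor; the fix yields [4, 1, 1], matching
  // warpsPerCTA = [4, 1, 1] in the expected DPAS encoding.
  assert((warpsPerTile({1, 64, 64}, 4) == std::vector<unsigned>{4, 1, 1}));
  return 0;
}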
