
Commit 50fc4c3

Add Tensor Layout verifier for DPAS layout (#4339)
Introduces a verifier to ensure the DPAS layout attached to a Dot operation has a suitable opsPerChannel parameter for the A and B operands of the Dot op. Previously this verification was implicit in the Triton GEN verification, producing a somewhat cryptic error message (prior to #4276 there was no error message at all):

```
test.mlir:16:11: error: 'triton_gen.dpas' op the dimension for the 2nd operand (A) should be equal to half of the repeat count
%28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
```

With the new verifier, the error message is more user friendly:

```
test.mlir:16:11: error: unexpected error: Operand 2 (%0 = "arith.constant"() <{value = dense<0.000000e+00> : tensor<32x32xf32, #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>>}> : () -> tensor<32x32xf32, #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>>) has an invalid layout: #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>. Layout has opsPerChannel = 2 but tensor element type is 'f32'. Expected 16 bit type.
%28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
```

Closes #4270
1 parent 92e9426 commit 50fc4c3
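For context, the check the verifier performs reduces to one arithmetic rule: a DPAS channel is 32 bits wide, so the layout's opsPerChannel multiplied by the operand element bit width must equal 32 (opsPerChannel = 1 for 32-bit, 2 for 16-bit, 4 for 8-bit element types). Below is a minimal standalone C++ sketch of that rule, for illustration only; the helper name `dpasOpsPerChannelMatches` is invented here and is not part of the commit.

```cpp
#include <cassert>

// Illustrative helper mirroring the relationship the new verifier enforces:
// a DPAS channel holds 32 bits, so opsPerChannel * element bit width == 32.
static bool dpasOpsPerChannelMatches(unsigned opsPerChannel,
                                     unsigned elemBitWidth) {
  return opsPerChannel * elemBitWidth == 32;
}

int main() {
  assert(dpasOpsPerChannelMatches(1, 32));   // f32 / tf32 operands
  assert(dpasOpsPerChannelMatches(2, 16));   // f16 / bf16 operands
  assert(dpasOpsPerChannelMatches(4, 8));    // i8 operands
  // The new negative tests use f32 operands with opsPerChan = 2,
  // which is exactly the mismatch the verifier now reports.
  assert(!dpasOpsPerChannelMatches(2, 32));
  return 0;
}
```

The two new failing-layout tests in tritonintelgpu-invalid.mlir exercise precisely the (opsPerChan = 2, f32) case shown above, once for the A operand and once for the B operand.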

File tree

2 files changed: +108 -0 lines changed


test/TritonIntelGPU/tritonintelgpu-invalid.mlir

Lines changed: 56 additions & 0 deletions
@@ -158,3 +158,59 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr
     tt.return %res : tensor<8x16xf16>
   }
 }
+
+
+// -----
+
+// COM: A operand mismatch
+#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
+#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
+#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: matmul_tf32dot
+  tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32>,
+                          %a_mat:tensor<32x16xf32, #dot_operand_a>, %b_mat:tensor<16x32xf32, #dot_operand_b>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
+
+    // expected-error @+1 {{Layout has opsPerChannel = 2 but tensor element type is 'f32'. Expected 16 bit type.}}
+    %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
+
+    tt.return
+  }
+}
+
+// -----
+
+// COM: B operand mismatch
+#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
+#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
+#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: matmul_tf32dot
+  tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32>,
+                          %a_mat:tensor<32x16xf16, #dot_operand_a>, %b_mat:tensor<16x32xf32, #dot_operand_b>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
+
+    // expected-error @+1 {{Layout has opsPerChannel = 2 but tensor element type is 'f32'. Expected 16 bit type.}}
+    %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf16, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
+
+    tt.return
+  }
+}
+
+// -----
+
+#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
+#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
+// expected-error @below {{ttg.dot_op kWidth parameter must match the parent's opsPerChannel}}
+#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: matmul_tf32dot
+  tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32>,
+                          %a_mat:tensor<32x16xf32, #dot_operand_a>, %b_mat:tensor<16x32xf32, #dot_operand_b>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
+    %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
+
+    tt.return
+  }
+}

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 52 additions & 0 deletions
@@ -1203,6 +1203,57 @@ struct TritonIntelGPUInferLayoutInterface
   }
 };
 
+struct TritonIntelGPUVerifyTensorLayoutInterface
+    : public triton::DialectVerifyTensorLayoutInterface {
+  using DialectVerifyTensorLayoutInterface::DialectVerifyTensorLayoutInterface;
+
+  LogicalResult verifyTensorLayout(
+      Attribute layout, RankedTensorType rankedTy, Operation *op,
+      function_ref<InFlightDiagnostic()> makeErr) const override {
+
+    // Verify that the DPAS layout opsPerChannel param matches the A and B
+    // operand types. Because the DotOperand layout is not part of the Triton
+    // Intel GPU dialect, we need to first check for a tt.dot operation. Then,
+    // we can compare the type of each operand to the Dot operation with the
+    // DPAS layout attached to the Dot operation.
+    if (auto dpasEncoding = dyn_cast<DpasEncodingAttr>(layout)) {
+
+      auto validateDotDpasLayout = [&](Type elemTy) -> LogicalResult {
+        if (auto ptrTy = dyn_cast<PointerType>(elemTy)) {
+          elemTy = ptrTy.getPointeeType();
+        }
+        const unsigned elemTyBitWidth = elemTy.getIntOrFloatBitWidth();
+
+        // We know opsPerChannel is either 1, 2, or 4 because of the DPAS
+        // verifier when the DPAS attribute is created. Here we verify that
+        // opsPerChannel matches the tensor type.
+        if (dpasEncoding.getOpsPerChannel() * elemTyBitWidth != 32) {
+          return makeErr() << layout << ".\nLayout has opsPerChannel = "
+                           << dpasEncoding.getOpsPerChannel()
+                           << " but tensor element type is " << elemTy
+                           << ". Expected "
+                           << 32 / dpasEncoding.getOpsPerChannel()
+                           << " bit type.";
+        }
+        return success();
+      };
+
+      if (auto dotOp = dyn_cast<DotOp>(op)) {
+        auto aElemTy = dotOp.getA().getType().getElementType();
+        auto bElemTy = dotOp.getB().getType().getElementType();
+
+        auto aResult = validateDotDpasLayout(aElemTy);
+        if (aResult.failed())
+          return aResult;
+
+        return validateDotDpasLayout(bElemTy);
+      }
+    }
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 
 void TritonIntelGPUDialect::initialize() {
@@ -1212,6 +1263,7 @@ void TritonIntelGPUDialect::initialize() {
   >();
 
   addInterfaces<TritonIntelGPUInferLayoutInterface>();
+  addInterfaces<TritonIntelGPUVerifyTensorLayoutInterface>();
 
   addOperations<
 #define GET_OP_LIST
