Commit cf664ef

address review comments
1 parent e5d61ee commit cf664ef

File tree

2 files changed: +15 −36 lines changed


test/TritonIntelGPU/tritonintelgpu-invalid.mlir

Lines changed: 4 additions & 15 deletions
@@ -163,22 +163,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr
 // -----
 
 #dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
-#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
-#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
 #dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   // CHECK-LABEL: matmul_tf32dot
-  tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
-                          %a:!ttg.memdesc<32x16xf32, #shared, #smem>, %b:!ttg.memdesc<16x32xf32, #shared, #smem>) {
+  tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32>,
+                          %a_mat:tensor<32x16xf32, #dot_operand_a>, %b_mat:tensor<16x32xf32, #dot_operand_b>) {
     %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
-    %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a>
-    %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b>
 
     // expected-error @+1 {{Layout has opsPerChannel = 2 but tensor element type is 'f32'. Expected 16 bit type.}}
     %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
-    %38 = ttg.convert_layout %28 : tensor<32x32xf32, #dpas> -> tensor<32x32xf32, #blocked>
 
     tt.return
   }
@@ -187,22 +182,16 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // -----
 
 #dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
-#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
-#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 #dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
 // expected-error @below {{ttg.dot_op kWidth parameter must match the parent's opsPerChannel}}
 #dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
   // CHECK-LABEL: matmul_tf32dot
-  tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
-                          %a:!ttg.memdesc<32x16xf32, #shared, #smem>, %b:!ttg.memdesc<16x32xf32, #shared, #smem>) {
+  tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32>,
+                          %a_mat:tensor<32x16xf32, #dot_operand_a>, %b_mat:tensor<16x32xf32, #dot_operand_b>) {
     %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
-    %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a>
-    %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b>
-
     %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
-    %38 = ttg.convert_layout %28 : tensor<32x32xf32, #dpas> -> tensor<32x32xf32, #blocked>
 
     tt.return
   }
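
Both test cases now pass the dot operands directly as function arguments carrying #ttg.dot_op layouts, so the #shared and #blocked layouts, the ttg.local_load ops, and the trailing ttg.convert_layout are dead and removed. The first case pins the verifier's new message: opsPerChan = 2 paired with 32-bit f32 operands must report "Expected 16 bit type." A minimal standalone sketch of that arithmetic (the expectedBitWidth helper is illustrative, not part of the commit):

// Hypothetical helper mirroring the verifier's check: a DPAS channel is
// 32 bits wide, so each element must occupy 32 / opsPerChannel bits.
#include <cassert>

static unsigned expectedBitWidth(unsigned opsPerChannel) {
  return 32 / opsPerChannel;
}

int main() {
  assert(expectedBitWidth(1) == 32); // e.g. tf32 operands
  assert(expectedBitWidth(2) == 16); // e.g. f16/bf16 operands
  assert(expectedBitWidth(4) == 8);  // e.g. i8 operands
  // The test above pairs opsPerChan = 2 with f32 (32 bits), so the
  // verifier reports "Expected 16 bit type."
  return 0;
}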

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 11 additions & 21 deletions
@@ -1222,40 +1222,30 @@ struct TritonIntelGPUVerifyTensorLayoutInterface
       if (auto ptrTy = dyn_cast<PointerType>(elemTy)) {
         elemTy = ptrTy.getPointeeType();
       }
-      const auto elemTyBitWidth = elemTy.getIntOrFloatBitWidth();
+      const unsigned elemTyBitWidth = elemTy.getIntOrFloatBitWidth();
 
       // We know opsPerChannel is either 1, 4, or 8 because of the DPAS
       // verifier when the DPAS attribute is created. Here we verify that
       // opsPerChannel matches the tensor type.
-      if (dpasEncoding.getOpsPerChannel() == 4 && elemTyBitWidth != 8) {
+      if (dpasEncoding.getOpsPerChannel() * elemTyBitWidth != 32) {
         return makeErr() << layout << ".\nLayout has opsPerChannel = "
                          << dpasEncoding.getOpsPerChannel()
                          << " but tensor element type is " << elemTy
-                         << ". Expected 8 bit type.";
-      } else if (dpasEncoding.getOpsPerChannel() == 2 &&
-                 elemTyBitWidth != 16) {
-        return makeErr() << layout << ".\nLayout has opsPerChannel = "
-                         << dpasEncoding.getOpsPerChannel()
-                         << " but tensor element type is " << elemTy
-                         << ". Expected 16 bit type.";
-      } else if (dpasEncoding.getOpsPerChannel() == 1 &&
-                 elemTyBitWidth != 32) {
-        return makeErr() << layout << ".\nLayout has opsPerChannel = "
-                         << dpasEncoding.getOpsPerChannel()
-                         << " but tensor element type is " << elemTy
-                         << ". Expected 32 bit type.";
+                         << ". Expected "
+                         << 32 / dpasEncoding.getOpsPerChannel()
+                         << " bit type.";
       }
       return success();
     };
 
-    if (isa<DotOp>(op)) {
-      auto dotOp = cast<DotOp>(op);
+    if (auto dotOp = dyn_cast<DotOp>(op)) {
       auto aElemTy = dotOp.getA().getType().getElementType();
-      auto result = validateDotDpasLayout(aElemTy);
-      if (result.failed())
-        return result;
-
       auto bElemTy = dotOp.getB().getType().getElementType();
+
+      auto aResult = validateDotDpasLayout(aElemTy);
+      if (aResult.failed())
+        return aResult;
+
       return validateDotDpasLayout(bElemTy);
     }
   }
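
The review rewrite makes two changes here: the three per-width branches collapse into the single check opsPerChannel * elemTyBitWidth != 32, with the expected width in the message derived as 32 / opsPerChannel, and the isa<DotOp>/cast<DotOp> pair folds into one dyn_cast. A self-contained sketch of the resulting control flow, with plain types standing in for the MLIR interfaces (the validate name and unsigned widths are stand-ins, not the real API):

#include <iostream>
#include <optional>
#include <string>

// Stand-in for validateDotDpasLayout: one arithmetic test replaces the
// old opsPerChannel == 4 / 2 / 1 branches, and the expected width in
// the message is computed the same way the new code computes it.
static std::optional<std::string> validate(unsigned opsPerChannel,
                                           unsigned elemTyBitWidth) {
  if (opsPerChannel * elemTyBitWidth != 32)
    return "Layout has opsPerChannel = " + std::to_string(opsPerChannel) +
           " but element width is " + std::to_string(elemTyBitWidth) +
           ". Expected " + std::to_string(32 / opsPerChannel) + " bit type.";
  return std::nullopt; // corresponds to success()
}

int main() {
  unsigned aBits = 32, bBits = 32; // both operands are f32, as in the test
  // Mirrors the new flow: validate A, return its failure early, then B.
  if (auto aResult = validate(2, aBits)) {
    std::cout << *aResult << '\n'; // prints "... Expected 16 bit type."
    return 1;
  }
  if (auto bResult = validate(2, bBits)) {
    std::cout << *bResult << '\n';
    return 1;
  }
  std::cout << "dot layout ok\n";
  return 0;
}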
