
Commit 48f361e

Add Tensor Layout verifier for DPAS layout
1 parent d9fffb7 commit 48f361e

File tree: 2 files changed, +111 -0 lines changed

test/TritonIntelGPU/tritonintelgpu-invalid.mlir

Lines changed: 49 additions & 0 deletions
@@ -158,3 +158,52 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr
     tt.return %res : tensor<8x16xf16>
   }
 }
+
+
+// -----
+
+#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
+#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: matmul_tf32dot
+  tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
+                          %a:!ttg.memdesc<32x16xf32, #shared, #smem>, %b:!ttg.memdesc<16x32xf32, #shared, #smem>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
+    %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a>
+    %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b>
+
+    // expected-error @+1 {{Layout has opsPerChannel = 2 but tensor element type is 'f32'. Expected 16 bit type.}}
+    %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
+    %38 = ttg.convert_layout %28 : tensor<32x32xf32, #dpas> -> tensor<32x32xf32, #blocked>
+
+    tt.return
+  }
+}
+
+// -----
+
+#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
+#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
+// expected-error @below {{ttg.dot_op kWidth parameter must match the parent's opsPerChannel}}
+#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: matmul_tf32dot
+  tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
+                          %a:!ttg.memdesc<32x16xf32, #shared, #smem>, %b:!ttg.memdesc<16x32xf32, #shared, #smem>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
+    %a_mat = ttg.local_load %a : !ttg.memdesc<32x16xf32, #shared, #smem> -> tensor<32x16xf32, #dot_operand_a>
+    %b_mat = ttg.local_load %b : !ttg.memdesc<16x32xf32, #shared, #smem> -> tensor<16x32xf32, #dot_operand_b>
+
+    %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
+    %38 = ttg.convert_layout %28 : tensor<32x32xf32, #dpas> -> tensor<32x32xf32, #blocked>
+
+    tt.return
+  }
+}

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 62 additions & 0 deletions
@@ -948,6 +948,67 @@ struct TritonIntelGPUInferLayoutInterface
   }
 };
 
+struct TritonIntelGPUVerifyTensorLayoutInterface
+    : public triton::DialectVerifyTensorLayoutInterface {
+  using DialectVerifyTensorLayoutInterface::DialectVerifyTensorLayoutInterface;
+
+  LogicalResult verifyTensorLayout(
+      Attribute layout, RankedTensorType rankedTy, Operation *op,
+      function_ref<InFlightDiagnostic()> makeErr) const override {
+
+    // Verify that the DPAS layout opsPerChannel param matches the A and B
+    // operand types. Because the DotOperand layout is not part of the Triton
+    // Intel GPU dialect, we need to first check for a TT.Dot operation. Then,
+    // we can compare the type of each operand to the Dot operation with the
+    // DPAS layout attached to the Dot operation.
+    if (auto dpasEncoding = dyn_cast<DpasEncodingAttr>(layout)) {
+
+      auto validateDotDpasLayout = [&](Type elemTy) -> LogicalResult {
+        if (auto ptrTy = dyn_cast<PointerType>(elemTy)) {
+          elemTy = ptrTy.getPointeeType();
+        }
+        const auto elemTyBitWidth = elemTy.getIntOrFloatBitWidth();
+
+        // We know opsPerChannel is either 1, 2, or 4 because of the DPAS
+        // verifier that runs when the DPAS attribute is created. Here we
+        // verify that opsPerChannel matches the tensor element type.
+        if (dpasEncoding.getOpsPerChannel() == 4 && elemTyBitWidth != 8) {
+          return makeErr() << layout << ".\nLayout has opsPerChannel = "
+                           << dpasEncoding.getOpsPerChannel()
+                           << " but tensor element type is " << elemTy
+                           << ". Expected 8 bit type.";
+        } else if (dpasEncoding.getOpsPerChannel() == 2 &&
+                   elemTyBitWidth != 16) {
+          return makeErr() << layout << ".\nLayout has opsPerChannel = "
+                           << dpasEncoding.getOpsPerChannel()
+                           << " but tensor element type is " << elemTy
+                           << ". Expected 16 bit type.";
+        } else if (dpasEncoding.getOpsPerChannel() == 1 &&
+                   elemTyBitWidth != 32) {
+          return makeErr() << layout << ".\nLayout has opsPerChannel = "
+                           << dpasEncoding.getOpsPerChannel()
+                           << " but tensor element type is " << elemTy
+                           << ". Expected 32 bit type.";
+        }
+        return success();
+      };
+
+      if (isa<DotOp>(op)) {
+        auto dotOp = cast<DotOp>(op);
+        auto aElemTy = dotOp.getA().getType().getElementType();
+        auto result = validateDotDpasLayout(aElemTy);
+        if (result.failed())
+          return result;
+
+        auto bElemTy = dotOp.getB().getType().getElementType();
+        return validateDotDpasLayout(bElemTy);
+      }
+    }
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 
 void TritonIntelGPUDialect::initialize() {
@@ -957,6 +1018,7 @@ void TritonIntelGPUDialect::initialize() {
       >();
 
   addInterfaces<TritonIntelGPUInferLayoutInterface>();
+  addInterfaces<TritonIntelGPUVerifyTensorLayoutInterface>();
 
   addOperations<
 #define GET_OP_LIST

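For reference, the rule the new verifier enforces pairs each DPAS opsPerChannel value with an expected dot-operand element width: 4 expects 8-bit, 2 expects 16-bit, and 1 expects 32-bit elements. The following minimal standalone C++ sketch is not part of the commit; the helper name expectedElemBitWidth and the sample values are illustrative only, restating that mapping outside of MLIR:

#include <cstdio>

// Expected element bit width for a DPAS opsPerChannel value, mirroring the
// checks added in TritonIntelGPUVerifyTensorLayoutInterface:
//   4 -> 8-bit, 2 -> 16-bit, 1 -> 32-bit.
// Returns 0 for values this sketch does not model.
static unsigned expectedElemBitWidth(unsigned opsPerChannel) {
  switch (opsPerChannel) {
  case 4:
    return 8;
  case 2:
    return 16;
  case 1:
    return 32;
  default:
    return 0;
  }
}

int main() {
  // The first invalid test case above: opsPerChan = 2 with f32 (32-bit)
  // operands, which the verifier rejects with "Expected 16 bit type."
  const unsigned opsPerChan = 2;
  const unsigned f32Bits = 32;
  if (expectedElemBitWidth(opsPerChan) != f32Bits)
    std::printf("opsPerChannel = %u expects %u-bit elements, got %u-bit\n",
                opsPerChan, expectedElemBitWidth(opsPerChan), f32Bits);
  return 0;
}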