diff --git a/test/TritonIntelGPU/tritonintelgpu-invalid.mlir b/test/TritonIntelGPU/tritonintelgpu-invalid.mlir
index 75d391460b..f95ebd9384 100644
--- a/test/TritonIntelGPU/tritonintelgpu-invalid.mlir
+++ b/test/TritonIntelGPU/tritonintelgpu-invalid.mlir
@@ -158,3 +158,59 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr
     tt.return %res : tensor<8x16xf16>
   }
 }
+
+
+// -----
+
+// COM: A operand mismatch
+#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
+#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
+#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: matmul_tf32dot
+  tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32>,
+    %a_mat:tensor<32x16xf32, #dot_operand_a>, %b_mat:tensor<16x32xf32, #dot_operand_b>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
+
+    // expected-error @+1 {{Layout has opsPerChannel = 2 but tensor element type is 'f32'. Expected 16 bit type.}}
+    %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
+
+    tt.return
+  }
+}
+
+// -----
+
+// COM: B operand mismatch
+#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
+#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
+#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: matmul_tf32dot
+  tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32>,
+    %a_mat:tensor<32x16xf16, #dot_operand_a>, %b_mat:tensor<16x32xf32, #dot_operand_b>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
+
+    // expected-error @+1 {{Layout has opsPerChannel = 2 but tensor element type is 'f32'. Expected 16 bit type.}}
+    %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf16, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
+
+    tt.return
+  }
+}
+
+// -----
+
+#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
+#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
+// expected-error @below {{ttg.dot_op kWidth parameter must match the parent's opsPerChannel}}
+#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+  // CHECK-LABEL: matmul_tf32dot
+  tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32>,
+    %a_mat:tensor<32x16xf32, #dot_operand_a>, %b_mat:tensor<16x32xf32, #dot_operand_b>) {
+    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
+    %28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
+
+    tt.return
+  }
+}
diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
index 138fccf6c0..a5d120ccc6 100644
--- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
+++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -1203,6 +1203,57 @@ struct TritonIntelGPUInferLayoutInterface
   }
 };
 
+struct TritonIntelGPUVerifyTensorLayoutInterface
+    : public triton::DialectVerifyTensorLayoutInterface {
+  using DialectVerifyTensorLayoutInterface::DialectVerifyTensorLayoutInterface;
+
+  LogicalResult verifyTensorLayout(
+      Attribute layout, RankedTensorType rankedTy, Operation *op,
+      function_ref<InFlightDiagnostic()> makeErr) const override {
+
+    // Verify that the DPAS layout opsPerChannel param matches the A and B
+    // operand types. Because the DotOperand layout is not part of the Triton
+    // Intel GPU dialect, we need to first check for a tt.dot operation. Then,
+    // we can compare the type of each operand to the Dot operation with the
+    // DPAS layout attached to the Dot operation.
+    if (auto dpasEncoding = dyn_cast<DpasEncodingAttr>(layout)) {
+
+      auto validateDotDpasLayout = [&](Type elemTy) -> LogicalResult {
+        if (auto ptrTy = dyn_cast<triton::PointerType>(elemTy)) {
+          elemTy = ptrTy.getPointeeType();
+        }
+        const unsigned elemTyBitWidth = elemTy.getIntOrFloatBitWidth();
+
+        // We know opsPerChannel is either 1, 2, or 4 because of the DPAS
+        // verifier when the DPAS attribute is created. Here we verify that
+        // opsPerChannel matches the tensor type.
+        if (dpasEncoding.getOpsPerChannel() * elemTyBitWidth != 32) {
+          return makeErr() << layout << ".\nLayout has opsPerChannel = "
+                           << dpasEncoding.getOpsPerChannel()
+                           << " but tensor element type is " << elemTy
+                           << ". Expected "
+                           << 32 / dpasEncoding.getOpsPerChannel()
+                           << " bit type.";
+        }
+        return success();
+      };
+
+      if (auto dotOp = dyn_cast<triton::DotOp>(op)) {
+        auto aElemTy = dotOp.getA().getType().getElementType();
+        auto bElemTy = dotOp.getB().getType().getElementType();
+
+        auto aResult = validateDotDpasLayout(aElemTy);
+        if (aResult.failed())
+          return aResult;
+
+        return validateDotDpasLayout(bElemTy);
+      }
+    }
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 
 void TritonIntelGPUDialect::initialize() {
@@ -1212,6 +1263,7 @@ void TritonIntelGPUDialect::initialize() {
       >();
 
   addInterfaces<TritonIntelGPUInferLayoutInterface>();
+  addInterfaces<TritonIntelGPUVerifyTensorLayoutInterface>();
 
   addOperations<
 #define GET_OP_LIST
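
For reference only, and not part of the patch itself: a minimal MLIR sketch of a tt.dot whose DPAS layout satisfies the new opsPerChannel check, reusing the same #ttig.dpas / #ttg.dot_op attributes as the negative tests above. With opsPerChan = 2, the operand element type must be 16 bits wide so that opsPerChannel * elemTyBitWidth == 32, which is exactly what the added verifyTensorLayout enforces. The function name @matmul_f16dot is hypothetical and the snippet has not been run against the full set of layout verifiers.

#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
  tt.func @matmul_f16dot(%a_mat:tensor<32x16xf16, #dot_operand_a>, %b_mat:tensor<16x32xf16, #dot_operand_b>) {
    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
    // COM: f16 is 16 bits wide, so opsPerChannel (2) * element bit width (16) == 32 and no diagnostic is emitted.
    %0 = tt.dot %a_mat, %b_mat, %cst : tensor<32x16xf16, #dot_operand_a> * tensor<16x32xf16, #dot_operand_b> -> tensor<32x32xf32, #dpas>
    tt.return
  }
}

As with the other cases in tritonintelgpu-invalid.mlir, the negative tests in the patch are presumably driven with --split-input-file and --verify-diagnostics; the sketch above would be the positive counterpart and belongs in a non-invalid test file if added at all.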