Skip to content

Commit c38da38

Browse files
committed
add B operand mismatch test
1 parent cf664ef commit c38da38

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

test/TritonIntelGPU/tritonintelgpu-invalid.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr
162162

163163
// -----
164164

165+
// COM: A operand mismatch
165166
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
166167
#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
167168
#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
@@ -181,6 +182,26 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
181182

182183
// -----
183184

185+
// COM: B operand mismatch
186+
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
187+
#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
188+
#dot_operand_b = #ttg.dot_op<{opIdx=1, parent=#dpas, kWidth=2}>
189+
#smem = #ttg.shared_memory
190+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
191+
// CHECK-LABEL: matmul_tf32dot
192+
tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32>,
193+
%a_mat:tensor<32x16xf16, #dot_operand_a>, %b_mat:tensor<16x32xf32, #dot_operand_b>) {
194+
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #dpas>
195+
196+
// expected-error @+1 {{Layout has opsPerChannel = 2 but tensor element type is 'f32'. Expected 16 bit type.}}
197+
%28 = tt.dot %a_mat, %b_mat, %cst, inputPrecision = tf32 : tensor<32x16xf16, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #dpas>
198+
199+
tt.return
200+
}
201+
}
202+
203+
// -----
204+
184205
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [1, 1]}>
185206
#dot_operand_a = #ttg.dot_op<{opIdx=0, parent=#dpas, kWidth=1}>
186207
// expected-error @below {{ttg.dot_op kWidth parameter must match the parent's opsPerChannel}}

0 commit comments

Comments
 (0)