
Commit 9e60885

Revert semantic changes to tests and add new ones
1 parent 99abe0f

3 files changed: +23 -2 lines changed

test/Triton/invalid.mlir

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ tt.func public @fn(%arg0: tensor<128xf32>, %arg1: tensor<64xf32>) {

 tt.func public @reshape_different_num_elements(%arg0: tensor<32x128xf16>) {
   // expected-error @+1 {{number of src and dst elements of reshape must be the same}}
-  %a = tt.reshape %arg0 allow_reorder : tensor<32x128xf16> -> tensor<64x32xf16>
+  %a = tt.reshape %arg0 : tensor<32x128xf16> -> tensor<64x32xf16>
   tt.return
 }
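For context, the expected-error annotation above is checked by MLIR's diagnostic verifier rather than by FileCheck: the reshape verifier rejects the op because the 32x128 source (4096 elements) cannot be reshaped into 64x32 (2048 elements). A minimal self-contained sketch of such a test file follows; the RUN line is an assumption and may differ from the one actually at the top of test/Triton/invalid.mlir.

// RUN: triton-opt --split-input-file --verify-diagnostics %s
// (assumed RUN line; triton-opt is built on mlir-opt, which provides both flags)

tt.func public @reshape_different_num_elements(%arg0: tensor<32x128xf16>) {
  // expected-error @+1 {{number of src and dst elements of reshape must be the same}}
  %a = tt.reshape %arg0 : tensor<32x128xf16> -> tensor<64x32xf16>
  tt.return
}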

test/Triton/ops.mlir

Lines changed: 21 additions & 0 deletions
@@ -225,11 +225,32 @@ tt.func @inline_asm_scalar(%0: i32) {

 // CHECK-LABEL: reshape
 tt.func @reshape(%0: tensor<512xi32>) {
+  // CHECK: tt.reshape %{{.+}} : tensor<512xi32> -> tensor<16x32xi32>
+  %1 = tt.reshape %0 : tensor<512xi32> -> tensor<16x32xi32>
+  tt.return
+}
+
+// CHECK-LABEL: reshape_allow_reorder
+tt.func @reshape_allow_reorder(%0: tensor<512xi32>) {
   // CHECK: tt.reshape %{{.+}} allow_reorder : tensor<512xi32> -> tensor<16x32xi32>
   %1 = tt.reshape %0 allow_reorder : tensor<512xi32> -> tensor<16x32xi32>
   tt.return
 }

+// CHECK-LABEL: reshape_efficient_layout
+tt.func @reshape_efficient_layout(%0: tensor<512xi32>) {
+  // CHECK: tt.reshape %{{.+}} efficient_layout : tensor<512xi32> -> tensor<16x32xi32>
+  %1 = tt.reshape %0 efficient_layout : tensor<512xi32> -> tensor<16x32xi32>
+  tt.return
+}
+
+// CHECK-LABEL: reshape_allow_reorder_efficient_layout
+tt.func @reshape_allow_reorder_efficient_layout(%0: tensor<512xi32>) {
+  // CHECK: tt.reshape %{{.+}} allow_reorder efficient_layout : tensor<512xi32> -> tensor<16x32xi32>
+  %1 = tt.reshape %0 allow_reorder efficient_layout : tensor<512xi32> -> tensor<16x32xi32>
+  tt.return
+}
+
 // CHECK-LABEL: histogram
 tt.func @histogram(%0: tensor<512xi32>) {
   // CHECK: tt.histogram %{{.+}} : tensor<512xi32> -> tensor<16xi32>
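These ops.mlir cases are, in effect, round-trip tests: triton-opt parses and re-prints the file, and FileCheck matches the printed output against the CHECK lines, so the four functions together pin down how each combination of the optional allow_reorder and efficient_layout keywords prints. A minimal standalone sketch of one such test, with an assumed RUN line that may differ from the real one in test/Triton/ops.mlir:

// RUN: triton-opt %s | FileCheck %s
// (assumed RUN line; the real ops.mlir may use additional flags)

// CHECK-LABEL: reshape_allow_reorder_efficient_layout
tt.func @reshape_allow_reorder_efficient_layout(%0: tensor<512xi32>) {
  // CHECK: tt.reshape %{{.+}} allow_reorder efficient_layout : tensor<512xi32> -> tensor<16x32xi32>
  %1 = tt.reshape %0 allow_reorder efficient_layout : tensor<512xi32> -> tensor<16x32xi32>
  tt.return
}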

test/TritonGPU/loop-pipeline.mlir

Lines changed: 1 addition & 1 deletion
@@ -1341,7 +1341,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 :
   %84 = arith.sitofp %82 : tensor<64x256xi8, #blocked> to tensor<64x256xf16, #blocked>
   %85 = tt.join %83, %84 : tensor<64x256xf16, #blocked> -> tensor<64x256x2xf16, #blocked3>
   %86 = tt.trans %85 {order = array<i32: 0, 2, 1>} : tensor<64x256x2xf16, #blocked3> -> tensor<64x2x256xf16, #blocked4>
-  %87 = tt.reshape %86 allow_reorder : tensor<64x2x256xf16, #blocked4> -> tensor<128x256xf16, #blocked5>
+  %87 = tt.reshape %86 : tensor<64x2x256xf16, #blocked4> -> tensor<128x256xf16, #blocked5>
   %88 = triton_gpu.convert_layout %78 : tensor<16x128xf16, #blocked1> -> tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
   %89 = triton_gpu.convert_layout %87 : tensor<128x256xf16, #blocked5> -> tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
   %90 = tt.dot %88, %89, %arg10 : tensor<16x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x256xf32, #mma>

0 commit comments
