Skip to content

Commit ce7d549

Browse files
committed
cleaner tests
Signed-off-by: dchigarev <[email protected]>
1 parent 049beb9 commit ce7d549

File tree

2 files changed

+4
-11
lines changed

2 files changed

+4
-11
lines changed

test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-broadcast-fold.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ func.func @complex_broadcast_3d() {
7777
%3 = memref.alloc() : memref<7x7x128xf16>
7878

7979
gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg6, %arg7, %arg8) in (%arg14 = %c4, %arg15 = %c1, %arg16 = %c1) {
80+
// This broadcast can't be replaced by a single memref.subview. Can't remove it
8081
// CHECK: linalg.broadcast
8182
linalg.broadcast ins(%0 : memref<7x128xf16>) outs(%1 : memref<7x7x128xf16>) dimensions = [0]
8283
linalg.add ins(%1, %2 : memref<7x7x128xf16>, memref<7x7x128xf16>) outs(%3 : memref<7x7x128xf16>)
@@ -99,9 +100,10 @@ func.func @single_broadcast() {
99100
%1 = memref.alloc() : memref<1x1x7x128xf16>
100101

101102
gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg6, %arg7, %arg8) in (%arg14 = %c4, %arg15 = %c1, %arg16 = %c1) {
103+
// broadcast result is not an input of any xegpu operation, we can't lower it
102104
// CHECK: linalg.broadcast
103105
linalg.broadcast ins(%0 : memref<7x128xf16>) outs(%1 : memref<1x1x7x128xf16>) dimensions = [0, 1]
104106
gpu.terminator
105107
}
106108
return
107-
}
109+
}

test/mlir/test/gc/gpu-runner/XeGPU/rope_sanity_test.mlir

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,29 +26,20 @@ func.func @RoPE(%iinput: !input_memref_type, %ipos_ids: !pos_ids_memref_type, %i
2626
%pos_ids = bufferization.to_tensor %ipos_ids restrict : !pos_ids_memref_type
2727
%pos_id_end = bufferization.to_tensor %ipos_id_end restrict : memref<1xindex>
2828
%3 = tensor.empty(): !output_tensor_type
29-
//call @stopTimerMy() : () -> ()
29+
3030
%transpose_in = linalg.transpose ins(%input: !input_tensor_type) outs(%3:!output_tensor_type) permutation = [0, 2, 1, 3]
3131

32-
//call @startTimerMy() : () -> ()
3332
%c0 = arith.constant 0 : index
3433
%c3 = arith.constant 3 : index
3534
%cos_cache_slice = tensor.extract_slice %cos_cache_tensor[0, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : !cos_sin_cache_tensor_type to !cos_sin_cache_tensor_shrink_type
3635
%cos_cache_slice2 = tensor.collapse_shape %cos_cache_slice [[0, 1], [2],[3]] : tensor<1x1x7x128x!dtype> into tensor<1x7x128x!dtype>
3736
%cos_cache_slice3 = tensor.collapse_shape %cos_cache_slice2 [[0, 1], [2]] : tensor<1x7x128x!dtype> into tensor<7x128x!dtype>
3837
%pos_ids_index=tensor.expand_shape %pos_ids [[0],[1,2]] output_shape [1, 7, 1] : tensor<1x7xindex> into tensor<1x7x1xindex>
39-
//call @stopTimerMy() : () -> ()
40-
41-
//call @startTimerMy() : () -> ()
4238

4339
%cos_cache_slice4 = tensor.gather %cos_cache_slice3[%pos_ids_index] gather_dims([0]) : (tensor<7x128x!dtype>, tensor<1x7x1xindex>) -> tensor<1x7x128x!dtype>
4440

45-
//call @stopTimerMy() : () -> ()
46-
47-
//call @startTimerMy() : () -> ()
48-
// %cos_cache_slice4 = tensor.expand_shape %cos_cache_slice3[[0,1],[2]] output_shape [1,7,128] : tensor<7x128x!dtype> into tensor<1x7x128x!dtype>
4941
%cos_cache_slice5 = tensor.expand_shape %cos_cache_slice4 [[0,1],[2],[3]] output_shape [1,1,7,128] : tensor<1x7x128x!dtype> into tensor<1x1x7x128x!dtype>
5042
%cos_cache_slice6 = tensor.collapse_shape %cos_cache_slice5 [[0,1,2],[3]] : tensor<1x1x7x128x!dtype> into tensor<7x128x!dtype>
51-
//call @stopTimerMy() : () -> ()
5243

5344
%cos_cache_slice7 = linalg.broadcast ins(%cos_cache_slice6: tensor<7x128x!dtype>) outs(%3: !output_tensor_type) dimensions = [0, 1]
5445
%input_apply_cos_cache = linalg.mul ins(%transpose_in, %cos_cache_slice7: !output_tensor_type, !output_tensor_type) outs(%3: !output_tensor_type) -> !output_tensor_type

0 commit comments

Comments
 (0)