Skip to content

Commit 1d17537

Browse files
committed
Add test case for dim0
1 parent 9d71167 commit 1d17537

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,9 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
170170
gpu.return
171171
}
172172

173-
// CHECK-LABEL: broadcast
173+
// CHECK-LABEL: broadcast_dim1
174174
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
175-
gpu.func @broadcast(%src: memref<24x1xf32>) {
175+
gpu.func @broadcast_dim1(%src: memref<24x1xf32>) {
176176
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
177177
-> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
178178
%load = xegpu.load_nd %tdesc
@@ -186,6 +186,22 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
186186
gpu.return
187187
}
188188

189+
// Broadcast along dim 0: a 1x32 load whose layout uses sg_layout = [1, 4] and
// sg_data = [1, 8] is broadcast to 12x32 (result layout sg_data = [12, 8]).
// The expectations below require the rewritten op to broadcast
// vector<1x8xf32> to vector<12x8xf32> and to carry only lane_layout and
// lane_data in its result layout attribute.
// CHECK-LABEL: broadcast_dim0
190+
// CHECK-SAME: %[[ARG_0:.*]]: memref<1x32xf32>
191+
gpu.func @broadcast_dim0(%src: memref<1x32xf32>) {
192+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x32xf32>
193+
-> !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
194+
%load = xegpu.load_nd %tdesc
195+
: !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
196+
-> vector<1x32xf32>
197+
// CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 8], lane_data = [1, 1]>}
198+
// CHECK-SAME: : vector<1x8xf32> to vector<12x8xf32>
199+
%broadcast = vector.broadcast %load
200+
{layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [12, 8], lane_layout = [1, 8], lane_data = [1, 1]>}
201+
: vector<1x32xf32> to vector<12x32xf32>
202+
gpu.return
203+
}
204+
189205
gpu.func @scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
190206
//CHECK: [[c0:%.+]] = arith.constant 0 : index
191207
//CHECK: [[c128:%.+]] = arith.constant 128 : index

0 commit comments

Comments
 (0)