@@ -170,9 +170,9 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
170170 gpu.return
171171 }
172172
173- // CHECK-LABEL: broadcast
173+ // CHECK-LABEL: broadcast_dim1
174174 // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
175- gpu.func @broadcast (%src: memref <24 x1 xf32 >) {
175+ gpu.func @broadcast_dim1 (%src: memref <24 x1 xf32 >) {
176176 %tdesc = xegpu.create_nd_tdesc %src [0 , 0 ] : memref <24 x1 xf32 >
177177 -> !xegpu.tensor_desc <24 x1 xf32 , #xegpu.layout <sg_layout = [2 , 1 ], sg_data = [12 , 1 ], lane_layout = [2 , 1 ], lane_data = [1 , 1 ]>>
178178 %load = xegpu.load_nd %tdesc
@@ -186,6 +186,22 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
186186 gpu.return
187187 }
188188
// Test case: vector.broadcast along dim 0, expanding a 1x32 value to 12x32.
// The source descriptor uses sg_layout = [1, 4] with sg_data = [1, 8], so each
// subgroup tile is 1x8; the broadcast result keeps sg_layout = [1, 4] but uses
// sg_data = [12, 8], so each subgroup tile of the result is 12x8.
// NOTE(review): presumably driven by a workgroup-to-subgroup distribution pass;
// the RUN line is not visible in this hunk - confirm against the file header.
189+ // CHECK-LABEL: broadcast_dim0
190+ // CHECK-SAME: %[[ARG_0:.*]]: memref<1x32xf32>
191+ gpu.func @broadcast_dim0 (%src: memref <1 x32 xf32 >) {
192+ %tdesc = xegpu.create_nd_tdesc %src [0 , 0 ] : memref <1 x32 xf32 >
193+ -> !xegpu.tensor_desc <1 x32 xf32 , #xegpu.layout <sg_layout = [1 , 4 ], sg_data = [1 , 8 ], lane_layout = [1 , 8 ], lane_data = [1 , 1 ]>>
194+ %load = xegpu.load_nd %tdesc
195+ : !xegpu.tensor_desc <1 x32 xf32 , #xegpu.layout <sg_layout = [1 , 4 ], sg_data = [1 , 8 ], lane_layout = [1 , 8 ], lane_data = [1 , 1 ]>>
196+ -> vector <1 x32 xf32 >
// Expected output: the broadcast operates on the per-subgroup shapes
// (1x8 -> 12x8), and its result layout attribute carries only lane_layout and
// lane_data (the sg_layout / sg_data fields no longer appear).
197+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 8], lane_data = [1, 1]>}
198+ // CHECK-SAME: : vector<1x8xf32> to vector<12x8xf32>
199+ %broadcast = vector.broadcast %load
200+ {layout_result_0 = #xegpu.layout <sg_layout = [1 , 4 ], sg_data = [12 , 8 ], lane_layout = [1 , 8 ], lane_data = [1 , 1 ]>}
201+ : vector <1 x32 xf32 > to vector <12 x32 xf32 >
202+ gpu.return
203+ }
204+
189205 gpu.func @scf_for (%arg0: memref <1024 x1024 xf16 >, %arg1: memref <1024 x1024 xf16 >, %arg2: memref <1024 x1024 xf32 >) {
190206 //CHECK: [[c0:%.+]] = arith.constant 0 : index
191207 //CHECK: [[c128:%.+]] = arith.constant 128 : index
0 commit comments