@@ -408,23 +408,20 @@ gpu.module @test_distribution {
408408 }
409409
410410 // CHECK-LABEL: vector_shape_cast
411- gpu.func @vector_shape_cast (%src: memref <256 x128 xf32 >) {
412- %tdesc = xegpu.create_nd_tdesc %src : memref <256 x128 xf32 >
413- -> !xegpu.tensor_desc <256 x128 xf32 , #xegpu.layout <sg_layout = [8 , 4 ], sg_data = [32 , 32 ], lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>>
414- %load = xegpu.load_nd %tdesc [0 , 0 ]
415- : !xegpu.tensor_desc <256 x128 xf32 , #xegpu.layout <sg_layout = [8 , 4 ], sg_data = [32 , 32 ], lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>>
416- -> vector <256 x128 xf32 >
417- //CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<32x1x32x1xf32>
418- %cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout <sg_layout = [8 , 1 , 4 , 1 ], sg_data = [32 , 1 , 32 , 1 ]>} : vector <256 x128 xf32 > to vector <256 x1 x128 x1 xf32 >
411+ gpu.func @vector_shape_cast () {
412+ %cst = arith.constant {layout_result_0 = #xegpu.slice <#xegpu.layout <sg_layout = [8 , 1 , 1 , 4 ], sg_data = [1 , 1 , 1 , 32 ]>, dims = [0 , 1 , 2 ]>} dense <10 > : vector <128 xindex >
413+ %step = vector.step {layout_result_0 = #xegpu.slice <#xegpu.layout <sg_layout = [8 , 1 , 1 , 4 ], sg_data = [1 , 1 , 1 , 32 ]>, dims = [0 , 1 , 2 ]>} : vector <128 xindex >
414+ %muli = arith.muli %cst , %step {layout_result_0 = #xegpu.slice <#xegpu.layout <sg_layout = [8 , 1 , 1 , 4 ], sg_data = [1 , 1 , 1 , 32 ]>, dims = [0 , 1 , 2 ]>} : vector <128 xindex >
415+ //CHECK: vector.shape_cast {{.*}} : vector<32xindex> to vector<1x1x1x32xindex>
416+ %shape_cast = vector.shape_cast %muli {layout_result_0 = #xegpu.layout <sg_layout = [8 , 1 , 1 , 4 ], sg_data = [1 , 1 , 1 , 32 ]>} : vector <128 xindex > to vector <1 x1 x1 x128 xindex >
419417 gpu.return
420418 }
421419
422- // CHECK-LABEL: broadcast
423- // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index
424- gpu.func @broadcast (%arg0: index , %arg1: index ) {
425- %muli = arith.muli %arg0 , %arg1 : index
426- // CHECK: vector.broadcast {{.*}} : index to vector<1x1x1x32xindex>
427- %broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout <sg_layout = [4 , 2 , 6 , 1 ], sg_data = [1 , 1 , 1 , 32 ]>} : index to vector <4 x2 x6 x32 xindex >
428- gpu.return
429- }
420+ // CHECK-LABEL: vector_broadcast
421+ gpu.func @vector_broadcast (%arg0: index , %arg1: index ) {
422+ %muli = arith.muli %arg0 , %arg1 : index
423+ // CHECK: vector.broadcast {{.*}} : index to vector<1x1x1x32xindex>
424+ %broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout <sg_layout = [4 , 2 , 6 , 1 ], sg_data = [1 , 1 , 1 , 32 ]>} : index to vector <4 x2 x6 x32 xindex >
425+ gpu.return
426+ }
430427}
0 commit comments