@@ -168,52 +168,48 @@ gpu.module @test {
168168// -----
169169// CHECK-LABEL: gpu.func @gemm_loop
170170// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) {
171- // CHECK: %[[BLOCK_ID_X:.*]] = gpu.block_id x
172- // CHECK: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
173- // CHECK: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index
174- // CHECK: %[[X_COORD:.*]] = arith.muli %[[BLOCK_ID_X]], %c8 : index
171+ // CHECK-DAG: %[[BLOCK_ID_X:.*]] = gpu.block_id x
172+ // CHECK-DAG: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
173+ // CHECK-DAG: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index
174+ // CHECK-DAG: %[[X_COORD:.*]] = arith.muli %[[BLOCK_ID_X]], %c8 : index
175175// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%[[X_COORD]], %[[Y_COORD]]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
176- // CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
177- // CHECK-DAG: %[[C_INIT:.*]] = vector.shape_cast %[[T3]] : vector<8xf32> to vector<8x1xf32>
178- // CHECK-DAG: %[[B_TILE:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}, %[[Y_COORD]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
179- // CHECK-DAG: %[[A_TILE:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[X_COORD]], %{{.*}}] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
180- // CHECK: %[[T7:.*]]:3 = scf.for {{.*}} iter_args(%[[C_VAL:.*]] = %[[C_INIT]], %[[A_ARG:.*]] = %[[A_TILE]], %[[B_ARG:.*]] = %[[B_TILE]]) -> (vector<8x1xf32>, !xegpu.tensor_desc<8x16xbf16>, !xegpu.tensor_desc<16x16xbf16>) {
181- // CHECK-DAG: %[[B_NEXT:.*]] = xegpu.update_nd_offset %[[B_ARG]], [{{.*}}] : !xegpu.tensor_desc<16x16xbf16>
182- // CHECK-DAG: %[[A_NEXT:.*]] = xegpu.update_nd_offset %[[A_ARG]], [{{.*}}] : !xegpu.tensor_desc<8x16xbf16>
183- // CHECK-DAG: %[[B:.*]] = xegpu.load_nd %[[B_ARG]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
184- // CHECK-DAG: %[[A:.*]] = xegpu.load_nd %[[A_ARG]] : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
185- // CHECK-DAG: %[[C:.*]] = vector.shape_cast %[[C_VAL]] : vector<8x1xf32> to vector<8xf32>
186- // CHECK-NEXT: %[[T8:.*]] = xegpu.dpas %[[A]], %[[B]], %[[C]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
187- // CHECK-NEXT: %[[C_OUT:.*]] = vector.shape_cast %[[T8]] : vector<8xf32> to vector<8x1xf32>
188- // CHECK-NEXT: scf.yield %[[C_OUT]], %[[A_NEXT]], %[[B_NEXT]] : vector<8x1xf32>, !xegpu.tensor_desc<8x16xbf16>, !xegpu.tensor_desc<16x16xbf16>
189- // CHECK-NEXT:}
190- // CHECK-NEXT: %[[C_FINAL:.*]] = vector.shape_cast %[[T7]]#0 : vector<8x1xf32> to vector<8xf32>
191- // CHECK-NEXT: xegpu.store_nd %[[C_FINAL]], %[[T2]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
176+ // CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
177+ // CHECK-NEXT: %[[T4:.*]] = vector.shape_cast %[[T3]] : vector<8xf32> to vector<8x1xf32>
178+ // CHECK: %[[T5:.*]] = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG4:.*]] = %[[T4]]) -> (vector<8x1xf32>) {
179+ // CHECK-DAG: %[[T10:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%[[K]], %[[Y_COORD]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
180+ // CHECK-DAG: %[[T11:.*]] = xegpu.load_nd %[[T10]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
181+ // CHECK-DAG: %[[T12:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[X_COORD]], %[[K]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
182+ // CHECK-DAG: %[[T13:.*]] = xegpu.load_nd %[[T12]] : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
183+ // CHECK-DAG: %[[T14:.*]] = vector.shape_cast %[[ARG4]] : vector<8x1xf32> to vector<8xf32>
184+ // CHECK-NEXT: %[[T15:.*]] = xegpu.dpas %[[T13]], %[[T11]], %[[T14]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
185+ // CHECK-NEXT: %[[T16:.*]] = vector.shape_cast %[[T15]] : vector<8xf32> to vector<8x1xf32>
186+ // CHECK-NEXT: scf.yield %[[T16]] : vector<8x1xf32>
187+ // CHECK-NEXT: }
188+ // CHECK-NEXT: %[[T9:.*]] = vector.shape_cast %[[T5]] : vector<8x1xf32> to vector<8xf32>
189+ // CHECK-NEXT: xegpu.store_nd %[[T9]], %[[T2]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
192190gpu.module @test {
193- gpu.func @gemm_loop (%arg0: memref <1024 x1024 xbf16 >, %arg1: memref <1024 x1024 xbf16 >, %arg2: memref <1024 x1024 xf32 >) {
194- %c0 = arith.constant 0 : index
195- %c16 = arith.constant 16 : index
196- %c8 = arith.constant 8 : index
197- %c1024 = arith.constant 1024 : index
198- %block_id_x = gpu.block_id x
199- %block_id_y = gpu.block_id y
200- %0 = arith.muli %block_id_x , %c8 : index
201- %1 = arith.muli %block_id_y , %c16 : index
202- %2 = xegpu.create_nd_tdesc %arg2 [%0 , %1 ] : memref <1024 x1024 xf32 > -> !xegpu.tensor_desc <8 x16 xf32 , #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>>
203- %3 = xegpu.load_nd %2 {layout_result_0 = #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>} : !xegpu.tensor_desc <8 x16 xf32 , #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>> -> vector <8 x16 xf32 >
204- %4 = xegpu.create_nd_tdesc %arg0 [%0 , %c0 ] : memref <1024 x1024 xbf16 > -> !xegpu.tensor_desc <8 x16 xbf16 , #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>>
205- %5 = xegpu.create_nd_tdesc %arg1 [%c0 , %1 ] : memref <1024 x1024 xbf16 > -> !xegpu.tensor_desc <16 x16 xbf16 , #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [2 , 1 ]>>
206- %6:3 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args (%arg4 = %3 , %arg5 = %4 , %arg6 = %5 ) -> (vector <8 x16 xf32 >, !xegpu.tensor_desc <8 x16 xbf16 , #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>>, !xegpu.tensor_desc <16 x16 xbf16 , #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [2 , 1 ]>>) {
207- %8 = xegpu.load_nd %arg5 {layout_result_0 = #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>} : !xegpu.tensor_desc <8 x16 xbf16 , #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>> -> vector <8 x16 xbf16 >
208- %9 = xegpu.load_nd %arg6 {layout_result_0 = #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [2 , 1 ]>} : !xegpu.tensor_desc <16 x16 xbf16 , #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [2 , 1 ]>> -> vector <16 x16 xbf16 >
209- %10 = xegpu.update_nd_offset %arg5 , [%c0 , %c16 ] : !xegpu.tensor_desc <8 x16 xbf16 , #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>>
210- %11 = xegpu.update_nd_offset %arg6 , [%c16 , %c0 ] : !xegpu.tensor_desc <16 x16 xbf16 , #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [2 , 1 ]>>
211- %12 = xegpu.dpas %8 , %9 , %arg4 {layout_operand_0 = #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>, layout_operand_1 = #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [2 , 1 ]>, layout_result_0 = #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>} : vector <8 x16 xbf16 >, vector <16 x16 xbf16 >, vector <8 x16 xf32 > -> vector <8 x16 xf32 >
212- scf.yield {layout_operand_0 = #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>} %12 , %10 , %11 : vector <8 x16 xf32 >, !xegpu.tensor_desc <8 x16 xbf16 , #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>>, !xegpu.tensor_desc <16 x16 xbf16 , #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [2 , 1 ]>>
213- } {layout_operand_3 = #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>, layout_result_0 = #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>}
214- xegpu.store_nd %6#0 , %2 {layout_operand_0 = #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>} : vector <8 x16 xf32 >, !xegpu.tensor_desc <8 x16 xf32 , #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>>
215- gpu.return
216- }
191+ gpu.func @gemm_loop (%arg0: memref <1024 x1024 xbf16 >, %arg1: memref <1024 x1024 xbf16 >, %arg2: memref <1024 x1024 xf32 >){  // C(%arg2) = A(%arg0) * B(%arg1) accumulated over K; one 8x16 C tile per workgroup
192+ %c0 = arith.constant 0 : index
193+ %c16 = arith.constant 16 : index  // K step and tile width
194+ %c8 = arith.constant 8 : index  // C/A tile height
195+ %c1024 = arith.constant 1024 : index  // K extent
196+ %block_id_x = gpu.block_id x
197+ %block_id_y = gpu.block_id y
198+ %0 = arith.muli %block_id_x , %c8 : index  // row offset of this workgroup's C tile
199+ %1 = arith.muli %block_id_y , %c16 : index  // column offset of this workgroup's C tile
200+ %2 = xegpu.create_nd_tdesc %arg2 [%0 , %1 ] : memref <1024 x1024 xf32 > -> !xegpu.tensor_desc <8 x16 xf32 , #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>>  // descriptor for the 8x16 output tile of C
201+ %3 = xegpu.load_nd %2 {layout_result_0 = #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>} : !xegpu.tensor_desc <8 x16 xf32 , #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>> -> vector <8 x16 xf32 >  // initial accumulator value from C
202+ %4 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args (%arg4 = %3 ) -> (vector <8 x16 xf32 >) {  // K loop; descriptors are re-created per iteration (no update_nd_offset), accumulator carried in %arg4
203+ %5 = xegpu.create_nd_tdesc %arg0 [%0 , %arg3 ] : memref <1024 x1024 xbf16 > -> !xegpu.tensor_desc <8 x16 xbf16 , #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>>  // A tile at [row, k]
204+ %6 = xegpu.create_nd_tdesc %arg1 [%arg3 , %1 ] : memref <1024 x1024 xbf16 > -> !xegpu.tensor_desc <16 x16 xbf16 , #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [2 , 1 ]>>  // B tile at [k, col]; lane_data [2,1] selects the packed B layout
205+ %7 = xegpu.load_nd %5 {layout_result_0 = #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>} : !xegpu.tensor_desc <8 x16 xbf16 , #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>> -> vector <8 x16 xbf16 >
206+ %8 = xegpu.load_nd %6 {layout_result_0 = #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [2 , 1 ]>} : !xegpu.tensor_desc <16 x16 xbf16 , #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [2 , 1 ]>> -> vector <16 x16 xbf16 >
207+ %9 = xegpu.dpas %7 , %8 , %arg4 {layout_operand_0 = #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>, layout_operand_1 = #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [2 , 1 ]>, layout_result_0 = #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>} : vector <8 x16 xbf16 >, vector <16 x16 xbf16 >, vector <8 x16 xf32 > -> vector <8 x16 xf32 >  // acc = A*B + acc
208+ scf.yield {layout_operand_0 = #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>} %9 : vector <8 x16 xf32 >
209+ } {layout_operand_3 = #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>, layout_result_0 = #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>}
210+ xegpu.store_nd %4 , %2 {layout_operand_0 = #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>} : vector <8 x16 xf32 >, !xegpu.tensor_desc <8 x16 xf32 , #xegpu.layout <lane_layout = [1 , 16 ], lane_data = [1 , 1 ]>>  // write final accumulator back to the C tile
211+ gpu.return
212+ }
217213}
218214
219215// -----
0 commit comments