Skip to content

Commit 341daff

Browse files
committed
fix test
1 parent 71902aa commit 341daff

File tree

1 file changed

+40
-44
lines changed

1 file changed

+40
-44
lines changed

mlir/test/Dialect/XeGPU/subgroup-distribute.mlir

Lines changed: 40 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -168,52 +168,48 @@ gpu.module @test {
168168
// -----
169169
// CHECK-LABEL: gpu.func @gemm_loop
170170
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) {
171-
// CHECK: %[[BLOCK_ID_X:.*]] = gpu.block_id x
172-
// CHECK: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
173-
// CHECK: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index
174-
// CHECK: %[[X_COORD:.*]] = arith.muli %[[BLOCK_ID_X]], %c8 : index
171+
// CHECK-DAG: %[[BLOCK_ID_X:.*]] = gpu.block_id x
172+
// CHECK-DAG: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
173+
// CHECK-DAG: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index
174+
// CHECK-DAG: %[[X_COORD:.*]] = arith.muli %[[BLOCK_ID_X]], %c8 : index
175175
// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%[[X_COORD]], %[[Y_COORD]]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
176-
// CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
177-
// CHECK-DAG: %[[C_INIT:.*]] = vector.shape_cast %[[T3]] : vector<8xf32> to vector<8x1xf32>
178-
// CHECK-DAG: %[[B_TILE:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%{{.*}}, %[[Y_COORD]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
179-
// CHECK-DAG: %[[A_TILE:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[X_COORD]], %{{.*}}] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
180-
// CHECK: %[[T7:.*]]:3 = scf.for {{.*}} iter_args(%[[C_VAL:.*]] = %[[C_INIT]], %[[A_ARG:.*]] = %[[A_TILE]], %[[B_ARG:.*]] = %[[B_TILE]]) -> (vector<8x1xf32>, !xegpu.tensor_desc<8x16xbf16>, !xegpu.tensor_desc<16x16xbf16>) {
181-
// CHECK-DAG: %[[B_NEXT:.*]] = xegpu.update_nd_offset %[[B_ARG]], [{{.*}}] : !xegpu.tensor_desc<16x16xbf16>
182-
// CHECK-DAG: %[[A_NEXT:.*]] = xegpu.update_nd_offset %[[A_ARG]], [{{.*}}] : !xegpu.tensor_desc<8x16xbf16>
183-
// CHECK-DAG: %[[B:.*]] = xegpu.load_nd %[[B_ARG]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
184-
// CHECK-DAG: %[[A:.*]] = xegpu.load_nd %[[A_ARG]] : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
185-
// CHECK-DAG: %[[C:.*]] = vector.shape_cast %[[C_VAL]] : vector<8x1xf32> to vector<8xf32>
186-
// CHECK-NEXT: %[[T8:.*]] = xegpu.dpas %[[A]], %[[B]], %[[C]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
187-
// CHECK-NEXT: %[[C_OUT:.*]] = vector.shape_cast %[[T8]] : vector<8xf32> to vector<8x1xf32>
188-
// CHECK-NEXT: scf.yield %[[C_OUT]], %[[A_NEXT]], %[[B_NEXT]] : vector<8x1xf32>, !xegpu.tensor_desc<8x16xbf16>, !xegpu.tensor_desc<16x16xbf16>
189-
// CHECK-NEXT:}
190-
// CHECK-NEXT: %[[C_FINAL:.*]] = vector.shape_cast %[[T7]]#0 : vector<8x1xf32> to vector<8xf32>
191-
// CHECK-NEXT: xegpu.store_nd %[[C_FINAL]], %[[T2]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
176+
// CHECK-NEXT: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
177+
// CHECK-NEXT: %[[T4:.*]] = vector.shape_cast %[[T3]] : vector<8xf32> to vector<8x1xf32>
178+
// CHECK: %[[T5:.*]] = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG4:.*]] = %[[T4]]) -> (vector<8x1xf32>) {
179+
// CHECK-DAG: %[[T10:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%[[K]], %[[Y_COORD]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
180+
// CHECK-DAG: %[[T11:.*]] = xegpu.load_nd %[[T10]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
181+
// CHECK-DAG: %[[T12:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[X_COORD]], %[[K]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
182+
// CHECK-DAG: %[[T13:.*]] = xegpu.load_nd %[[T12]] : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
183+
// CHECK-DAG: %[[T14:.*]] = vector.shape_cast %[[ARG4]] : vector<8x1xf32> to vector<8xf32>
184+
// CHECK-NEXT: %[[T15:.*]] = xegpu.dpas %[[T13]], %[[T11]], %[[T14]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
185+
// CHECK-NEXT: %[[T16:.*]] = vector.shape_cast %[[T15]] : vector<8xf32> to vector<8x1xf32>
186+
// CHECK-NEXT: scf.yield %[[T16]] : vector<8x1xf32>
187+
// CHECK-NEXT: }
188+
// CHECK-NEXT: %[[T9:.*]] = vector.shape_cast %[[T5]] : vector<8x1xf32> to vector<8xf32>
189+
// CHECK-NEXT: xegpu.store_nd %[[T9]], %[[T2]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
192190
gpu.module @test {
193-
gpu.func @gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>) {
194-
%c0 = arith.constant 0 : index
195-
%c16 = arith.constant 16 : index
196-
%c8 = arith.constant 8 : index
197-
%c1024 = arith.constant 1024 : index
198-
%block_id_x = gpu.block_id x
199-
%block_id_y = gpu.block_id y
200-
%0 = arith.muli %block_id_x, %c8 : index
201-
%1 = arith.muli %block_id_y, %c16 : index
202-
%2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
203-
%3 = xegpu.load_nd %2 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf32>
204-
%4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
205-
%5 = xegpu.create_nd_tdesc %arg1[%c0, %1] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
206-
%6:3 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %3, %arg5 = %4, %arg6 = %5) -> (vector<8x16xf32>, !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>) {
207-
%8 = xegpu.load_nd %arg5 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xbf16>
208-
%9 = xegpu.load_nd %arg6 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xbf16>
209-
%10 = xegpu.update_nd_offset %arg5, [%c0, %c16] : !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
210-
%11 = xegpu.update_nd_offset %arg6, [%c16, %c0] : !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
211-
%12 = xegpu.dpas %8, %9, %arg4 {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
212-
scf.yield {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} %12, %10, %11 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>, !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
213-
} {layout_operand_3 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
214-
xegpu.store_nd %6#0, %2 {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
215-
gpu.return
216-
}
191+
gpu.func @gemm_loop(%arg0: memref<1024x1024xbf16>, %arg1: memref<1024x1024xbf16>, %arg2: memref<1024x1024xf32>){
192+
%c0 = arith.constant 0 : index
193+
%c16 = arith.constant 16 : index
194+
%c8 = arith.constant 8 : index
195+
%c1024 = arith.constant 1024 : index
196+
%block_id_x = gpu.block_id x
197+
%block_id_y = gpu.block_id y
198+
%0 = arith.muli %block_id_x, %c8 : index
199+
%1 = arith.muli %block_id_y, %c16 : index
200+
%2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
201+
%3 = xegpu.load_nd %2 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xf32>
202+
%4 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args(%arg4 = %3) -> (vector<8x16xf32>) {
203+
%5 = xegpu.create_nd_tdesc %arg0[%0, %arg3] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
204+
%6 = xegpu.create_nd_tdesc %arg1[%arg3, %1] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>>
205+
%7 = xegpu.load_nd %5 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<8x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<8x16xbf16>
206+
%8 = xegpu.load_nd %6 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>} : !xegpu.tensor_desc<16x16xbf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>> -> vector<16x16xbf16>
207+
%9 = xegpu.dpas %7, %8, %arg4 {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xbf16>, vector<16x16xbf16>, vector<8x16xf32> -> vector<8x16xf32>
208+
scf.yield {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} %9 : vector<8x16xf32>
209+
} {layout_operand_3 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
210+
xegpu.store_nd %4, %2 {layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
211+
gpu.return
212+
}
217213
}
218214

219215
// -----

0 commit comments

Comments
 (0)