@@ -226,52 +226,3 @@ gpu.func @prefetch_1d(%arg0: memref<256xf16>){
226226 gpu.return
227227}
228228}
229-
230-
231- // -----
232- // CHECK-LABEL: gpu.func @gemm_loop
233- // CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<1024x1024xbf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<1024x1024xf32>) {
234- // CHECK: %[[BLOCK_ID_Y:.*]] = gpu.block_id y
235- // CHECK: %[[Y_COORD:.*]] = arith.muli %[[BLOCK_ID_Y]], %c16 : index
236- // CHECK: %[[BLOCK_ID_X:.*]] = gpu.block_id x
237- // CHECK: %[[X_COORD:.*]] = arith.muli %[[BLOCK_ID_X]], %c8 : index
238- // CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG2]][%[[X_COORD]], %[[Y_COORD]]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
239- // CHECK: %[[T3:.*]] = xegpu.load_nd %[[T2]] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
240- // CHECK: %[[T4:.*]] = vector.shape_cast %[[T3]] : vector<8xf32> to vector<8x1xf32>
241- // CHECK: %[[T5:.*]] = scf.for %[[K:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG4:.*]] = %[[T4]]) -> (vector<8x1xf32>) {
242- // CHECK: %[[T10:.*]] = xegpu.create_nd_tdesc %[[ARG1]][%[[K]], %[[Y_COORD]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<16x16xbf16>
243- // CHECK: %[[T11:.*]] = xegpu.load_nd %[[T10]] <{packed}> : !xegpu.tensor_desc<16x16xbf16> -> vector<16xbf16>
244- // CHECK: %[[T12:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[X_COORD]], %[[K]]] : memref<1024x1024xbf16> -> !xegpu.tensor_desc<8x16xbf16>
245- // CHECK: %[[T13:.*]] = xegpu.load_nd %[[T12]] : !xegpu.tensor_desc<8x16xbf16> -> vector<8xbf16>
246- // CHECK: %[[T14:.*]] = vector.shape_cast %[[ARG4]] : vector<8x1xf32> to vector<8xf32>
247- // CHECK: %[[T15:.*]] = xegpu.dpas %[[T13]], %[[T11]], %[[T14]] : vector<8xbf16>, vector<16xbf16>, vector<8xf32> -> vector<8xf32>
248- // CHECK: %[[T16:.*]] = vector.shape_cast %[[T15]] : vector<8xf32> to vector<8x1xf32>
249- // CHECK: scf.yield %[[T16]] : vector<8x1xf32>
250- // CHECK: }
251- // CHECK: %[[T8:.*]] = xegpu.create_nd_tdesc %[[ARG2]]{{.*}} : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x16xf32>
252- // CHECK: %[[T9:.*]] = vector.shape_cast %[[T5]] : vector<8x1xf32> to vector<8xf32>
253- // CHECK: xegpu.store_nd %[[T9]], %[[T8]] : vector<8xf32>, !xegpu.tensor_desc<8x16xf32>
254- gpu.module @test {
255- gpu.func @gemm_loop (%arg0: memref <1024 x1024 xbf16 >, %arg1: memref <1024 x1024 xbf16 >, %arg2: memref <1024 x1024 xf32 >){
256- %c0 = arith.constant 0 : index
257- %c16 = arith.constant 16 : index
258- %c8 = arith.constant 8 : index
259- %c1024 = arith.constant 1024 : index
260- %0 = gpu.block_id x
261- %1 = gpu.block_id y
262- %2 = arith.muli %0 , %c8 : index
263- %3 = arith.muli %1 , %c16 : index
264- %4 = xegpu.create_nd_tdesc %arg2 [%2 , %3 ] : memref <1024 x1024 xf32 > -> !xegpu.tensor_desc <8 x16 xf32 >
265- %5 = xegpu.load_nd %4 : !xegpu.tensor_desc <8 x16 xf32 > -> vector <8 x16 xf32 >
266- %6 = scf.for %arg3 = %c0 to %c1024 step %c16 iter_args (%arg4 = %5 ) -> (vector <8 x16 xf32 >) {
267- %7 = xegpu.create_nd_tdesc %arg0 [%2 , %arg3 ] : memref <1024 x1024 xbf16 > -> !xegpu.tensor_desc <8 x16 xbf16 >
268- %8 = xegpu.create_nd_tdesc %arg1 [%arg3 , %3 ] : memref <1024 x1024 xbf16 > -> !xegpu.tensor_desc <16 x16 xbf16 >
269- %9 = xegpu.load_nd %7 : !xegpu.tensor_desc <8 x16 xbf16 > -> vector <8 x16 xbf16 >
270- %10 = xegpu.load_nd %8 : !xegpu.tensor_desc <16 x16 xbf16 > -> vector <16 x16 xbf16 >
271- %11 = xegpu.dpas %9 , %10 , %arg4 : vector <8 x16 xbf16 >, vector <16 x16 xbf16 >, vector <8 x16 xf32 > -> vector <8 x16 xf32 >
272- scf.yield %11 : vector <8 x16 xf32 >
273- }
274- xegpu.store_nd %6 , %4 : vector <8 x16 xf32 >, !xegpu.tensor_desc <8 x16 xf32 >
275- gpu.return
276- }
277- }
0 commit comments