@@ -82,6 +82,49 @@ gpu.module @test_kernel {
8282 }
8383}
8484
// -----
// GEMM-shaped blocking test: one 8x32 f32 tile of C is loaded, accumulated in
// an scf.for loop from 8x16 f16 tiles of A and 16x32 f16 tiles of B via
// xegpu.dpas, and stored back to C.
//
// #l1 requests inst_data = [8, 16] and #l2 requests inst_data = [16, 16].
// The expectations below look for ops on descriptors that no longer carry a
// layout: two 8x16 pieces for each 8x32 operation on C, two 16x16 pieces for
// each 16x32 operation on B, and a single 8x16 piece for A.
//
// The expectations deliberately match on op name and types only and never on
// auto-numbered SSA value names, which shift whenever the pass emits a
// different number of ops ahead of the matched line.
#l1 = #xegpu.layout<inst_data = [8, 16]>
#l2 = #xegpu.layout<inst_data = [16, 16]>
gpu.module @test_kernel {
  gpu.func @test_gemm(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
    %c0 = arith.constant 0 : index
    %c8 = arith.constant 8 : index
    %c16 = arith.constant 16 : index
    %c32 = arith.constant 32 : index
    %c1024 = arith.constant 1024 : index
    %block_id_x = gpu.block_id x
    %block_id_y = gpu.block_id y
    // Offsets of this block's C tile: 8 rows per block in x, 32 columns per block in y.
    %m = arith.muli %block_id_x, %c8 : index
    %n = arith.muli %block_id_y, %c32 : index

    %c_tdesc = xegpu.create_nd_tdesc %C[%m, %n] : memref<1024x1024xf32> -> !xegpu.tensor_desc<8x32xf32, #l1>

    //CHECK-COUNT-2: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
    %c_init = xegpu.load_nd %c_tdesc : !xegpu.tensor_desc<8x32xf32, #l1> -> vector<8x32xf32>

    %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16, #l1>
    %b_tdesc = xegpu.create_nd_tdesc %B[%c0, %n] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #l2>
    // Loop carries the A descriptor, the B descriptor and the accumulator.
    %out:3 = scf.for %k = %c0 to %c1024 step %c16
      iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_init)
      -> (!xegpu.tensor_desc<8x16xf16, #l1>, !xegpu.tensor_desc<16x32xf16, #l2>, vector<8x32xf32>) {
      //CHECK: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
      %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<8x16xf16, #l1> -> vector<8x16xf16>
      //CHECK-COUNT-2: xegpu.load_nd {{.*}} : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
      %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<16x32xf16, #l2> -> vector<16x32xf16>
      %c = xegpu.dpas %a, %b, %arg2 {layout_result_0 = #l1} : vector<8x16xf16>, vector<16x32xf16>, vector<8x32xf32> -> vector<8x32xf32>
      // NOTE(review): the loop steps by 16 while both descriptors advance by
      // 32; this does not affect what is matched below, but confirm the
      // mismatch is intentional.
      //CHECK: xegpu.update_nd_offset {{.*}} [%c0, %c32] : !xegpu.tensor_desc<8x16xf16>
      %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<8x16xf16, #l1>
      //CHECK-COUNT-2: xegpu.update_nd_offset {{.*}} [%c32, %c0] : !xegpu.tensor_desc<16x16xf16>
      %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c32, %c0] : !xegpu.tensor_desc<16x32xf16, #l2>
      scf.yield %a_next_tdesc, %b_next_tdesc, %c
        : !xegpu.tensor_desc<8x16xf16, #l1>, !xegpu.tensor_desc<16x32xf16, #l2>, vector<8x32xf32>
    }
    //CHECK-COUNT-2: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
    xegpu.store_nd %out#2, %c_tdesc : vector<8x32xf32>, !xegpu.tensor_desc<8x32xf32, #l1>
    gpu.return
  }
}
127+
85128// -----
86129#a = #xegpu.layout <inst_data = [8 , 16 ], lane_layout = [1 , 16 ], lane_data = [8 , 1 ]>
87130#b = #xegpu.layout <inst_data = [16 , 16 ], lane_layout = [1 , 16 ], lane_data = [16 , 1 ]>
0 commit comments