@@ -68,6 +68,46 @@ gpu.module @test_kernel {
 }
 }
 
+// -----
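+// Elementwise add over 12x32 f16 tiles: the CHECK lines below verify that an
+// inst_data = [4, 16] layout is attached to the loads, the addf, the store,
+// and the offset updates.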
+gpu.module @test_kernel {
+  gpu.func @elementwise_with_inst_data_12(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf16>) {
+    %c0 = arith.constant 0 : index
+    %c32 = arith.constant 32 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id x
+    %block_id_y = gpu.block_id y
+    %m = arith.muli %block_id_x, %c32 : index
+
+    %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<12x32xf16>
+    %b_tdesc = xegpu.create_nd_tdesc %B[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<12x32xf16>
+    %c_tdesc = xegpu.create_nd_tdesc %C[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<12x32xf16>
+
+    %out:3 = scf.for %k = %c0 to %c1024 step %c32
+      iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc)
+      -> (!xegpu.tensor_desc<12x32xf16>, !xegpu.tensor_desc<12x32xf16>, !xegpu.tensor_desc<12x32xf16>) {
+      //CHECK: xegpu.load_nd {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>} :
+      //CHECK-SAME: !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<12x32xf16>
+      %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<12x32xf16> -> vector<12x32xf16>
+      %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<12x32xf16> -> vector<12x32xf16>
+
+      //CHECK: arith.addf {{.*}} {layout_result_0 = #xegpu.layout<inst_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>} : vector<12x32xf16>
+      %c = arith.addf %a, %b : vector<12x32xf16>
+
+      //CHECK: xegpu.store_nd {{.*}} : vector<12x32xf16>, !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+      xegpu.store_nd %c, %arg2 : vector<12x32xf16>, !xegpu.tensor_desc<12x32xf16>
+
+      //CHECK-COUNT-3: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<12x32xf16, #xegpu.layout<inst_data = [4, 16], lane_layout = [1, 16], lane_data = [1, 1]>>
+      %a_next_tdesc = xegpu.update_nd_offset %arg0, [%c0, %c32] : !xegpu.tensor_desc<12x32xf16>
+      %b_next_tdesc = xegpu.update_nd_offset %arg1, [%c0, %c32] : !xegpu.tensor_desc<12x32xf16>
+      %c_next_tdesc = xegpu.update_nd_offset %arg2, [%c0, %c32] : !xegpu.tensor_desc<12x32xf16>
+      scf.yield %a_next_tdesc, %b_next_tdesc, %c_next_tdesc
+        : !xegpu.tensor_desc<12x32xf16>, !xegpu.tensor_desc<12x32xf16>, !xegpu.tensor_desc<12x32xf16>
+    }
+    gpu.return
+  }
+}
+
 // -----
 gpu.module @test {
 // CHECK-LABEL: func.func @scatter_ops_chunksize(