@@ -90,18 +90,28 @@ func.func @extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor
9090}
9191
9292// -----
93- // CHECK-LABEL: func.func @load_gather_with_chunksize
94- // CHECK-SAME: [[arg0:%.+]]: memref<256xf16>
95- // CHECK: [[idx:%.+]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
96- // CHECK: [[m:%.+]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
97- // CHECK: [[desc:%.+]] = xegpu.create_tdesc [[arg0]], [[idx]] : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
98- // CHECK: xegpu.load [[desc]], [[m]] : !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x8xf16>
99- func.func @load_gather_with_chunksize (%arg0: memref <256 xf16 >) -> vector <16 x8 xf16 > {
100- %index = arith.constant dense <[0 , 16 , 32 , 48 , 64 , 80 , 96 , 112 , 128 , 144 , 160 , 176 , 192 , 208 , 224 , 240 ]> : vector <16 xindex >
101- %mask = arith.constant dense <true > : vector <16 xi1 >
102- %1 = xegpu.create_tdesc %arg0 , %index : memref <256 xf16 >, vector <16 xindex > -> !xegpu.tensor_desc <16 x8 xf16 , #xegpu.scatter_tdesc_attr <chunk_size = 8 : i64 >>
103- %2 = xegpu.load %1 , %mask : !xegpu.tensor_desc <16 x8 xf16 , #xegpu.scatter_tdesc_attr <chunk_size = 8 : i64 >>, vector <16 xi1 > -> vector <16 x8 xf16 >
104- return %2: vector <16 x8 xf16 >
93+ // CHECK-LABEL: func.func @load_gather_with_chunksize(
94+ // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
95+ // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
96+ // CHECK-SAME: dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
97+ // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
98+ // CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
99+ // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
100+ // CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
101+ // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
102+ func.func @load_gather_with_chunksize (%arg0: memref <8 x16 xf16 >, %arg1: memref <256 xf16 >, %arg2: memref <8 x16 xf32 >) {
103+ %c0 = arith.constant 0 : index
104+ %0 = xegpu.create_nd_tdesc %arg0 [%c0 , %c0 ] : memref <8 x16 xf16 > -> !xegpu.tensor_desc <8 x16 xf16 >
105+ %1 = xegpu.load_nd %0 : !xegpu.tensor_desc <8 x16 xf16 > -> vector <8 x16 xf16 >
106+ %cst = arith.constant dense <[0 , 16 , 32 , 48 , 64 , 80 , 96 , 112 , 128 , 144 , 160 , 176 , 192 , 208 , 224 , 240 ]> : vector <16 xindex >
107+ %cst_0 = arith.constant dense <true > : vector <16 xi1 >
108+ %2 = xegpu.create_tdesc %arg1 , %cst : memref <256 xf16 >, vector <16 xindex > -> !xegpu.tensor_desc <16 x16 xf16 , #xegpu.scatter_tdesc_attr <chunk_size = 16 : i64 >>
109+ %3 = xegpu.load %2 , %cst_0 : !xegpu.tensor_desc <16 x16 xf16 , #xegpu.scatter_tdesc_attr <chunk_size = 16 : i64 >>, vector <16 xi1 > -> vector <16 x16 xf16 >
110+ %4 = vector.transpose %3 , [1 , 0 ] : vector <16 x16 xf16 > to vector <16 x16 xf16 >
111+ %5 = xegpu.dpas %1 , %4 : vector <8 x16 xf16 >, vector <16 x16 xf16 > -> vector <8 x16 xf32 >
112+ %6 = xegpu.create_nd_tdesc %arg2 [%c0 , %c0 ] : memref <8 x16 xf32 > -> !xegpu.tensor_desc <8 x16 xf32 >
113+ xegpu.store_nd %5 , %6 : vector <8 x16 xf32 >, !xegpu.tensor_desc <8 x16 xf32 >
114+ return
105115}
106116
107117// -----
0 commit comments