
Commit 96da366

address comments
1 parent 8ea09e9 commit 96da366

2 files changed (+28 / -18 lines)


mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 6 additions & 6 deletions
@@ -323,7 +323,7 @@ LogicalResult TensorDescType::verify(
 
     // for gather and scatter ops, Low-precision types are packed in 32-bit units.
    unsigned bitWidth = elementType.getIntOrFloatBitWidth();
-    int packingFactor =
+    int chunkAlignmentFactor =
        bitWidth < targetinfo::packedSizeInBitsForGatherScatter
            ? targetinfo::packedSizeInBitsForGatherScatter / bitWidth
            : 1;
@@ -342,10 +342,10 @@ LogicalResult TensorDescType::verify(
    if (chunkSize > 1) {
      if (shape.back() != chunkSize)
        return emitError() << "expected tensor shape[1] to match chunk size";
-      if (shape.back() % packingFactor != 0)
-        return emitError()
-            << "expected tensor shape[1] to be a multiple of packing factor "
-            << packingFactor;
+      if (shape.back() % chunkAlignmentFactor != 0)
+        return emitError() << "expected tensor shape[1] to be a multiple of "
+                              "chunk alignment factor "
+                           << chunkAlignmentFactor;
    }
  }
 
@@ -365,7 +365,7 @@ LogicalResult TensorDescType::verify(
    if (rank > 1 && laneData[0] != 1)
      return emitError()
          << "cannot map over non-contiguous scattered row elements";
-    if (laneData[rank - 1] != packingFactor)
+    if (laneData[rank - 1] != chunkAlignmentFactor)
      return emitError() << "work item data mapping must match the number of "
                            "contiguous elements";
  }
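
A worked example of the renamed check (a sketch, not part of the commit; it assumes the 32-bit packing unit named in the comment above): for f16 elements, bitWidth = 16, so chunkAlignmentFactor = 32 / 16 = 2 and the trailing (chunk) dimension of a scattered tensor_desc must be even. In MLIR terms, with hypothetical values %src and %offsets:

// Accepted: chunk_size = 8 and 8 % 2 == 0.
%ok = xegpu.create_tdesc %src, %offsets : memref<256xf16>, vector<16xindex>
    -> !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
// A hypothetical 16x7xf16 descriptor with chunk_size = 7 would instead fail
// verification: "expected tensor shape[1] to be a multiple of chunk alignment factor 2".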

mlir/test/Dialect/XeGPU/propagate-layout.mlir

Lines changed: 22 additions & 12 deletions
@@ -90,18 +90,28 @@ func.func @extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor
 }
 
 // -----
-// CHECK-LABEL: func.func @load_gather_with_chunksize
-// CHECK-SAME: [[arg0:%.+]]: memref<256xf16>
-// CHECK: [[idx:%.+]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-// CHECK: [[m:%.+]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
-// CHECK: [[desc:%.+]] = xegpu.create_tdesc [[arg0]], [[idx]] : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
-// CHECK: xegpu.load [[desc]], [[m]] : !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x8xf16>
-func.func @load_gather_with_chunksize(%arg0: memref<256xf16>) -> vector<16x8xf16> {
-  %index = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
-  %mask = arith.constant dense<true> : vector<16xi1>
-  %1 = xegpu.create_tdesc %arg0, %index : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>
-  %2 = xegpu.load %1, %mask : !xegpu.tensor_desc<16x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8 : i64>>, vector<16xi1> -> vector<16x8xf16>
-  return %2: vector<16x8xf16>
+// CHECK-LABEL: func.func @load_gather_with_chunksize(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<8x16xf16>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf16>, %[[ARG2:[0-9a-zA-Z]+]]: memref<8x16xf32>) {
+// CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+// CHECK-SAME: dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+// CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
+// CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
+func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg0[%c0, %c0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+  %1 = xegpu.load_nd %0 : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+  %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
+  %cst_0 = arith.constant dense<true> : vector<16xi1>
+  %2 = xegpu.create_tdesc %arg1, %cst : memref<256xf16>, vector<16xindex> -> !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>
+  %3 = xegpu.load %2, %cst_0 : !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>>, vector<16xi1> -> vector<16x16xf16>
+  %4 = vector.transpose %3, [1, 0] : vector<16x16xf16> to vector<16x16xf16>
+  %5 = xegpu.dpas %1, %4 : vector<8x16xf16>, vector<16x16xf16> -> vector<8x16xf32>
+  %6 = xegpu.create_nd_tdesc %arg2[%c0, %c0] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+  xegpu.store_nd %5, %6 : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
+  return
 }
 
 // -----
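
The expected layout in the updated test follows from the same factor (a sketch of the reasoning, assuming the 32-bit packing for f16 as above): the verifier requires the innermost lane_data entry to equal chunkAlignmentFactor = 2, hence the test expects

#xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>
// 16 lanes, each mapped to one row of the descriptor; every lane consumes
// pairs of f16 values, i.e. one packed 32-bit unit per access.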
