Skip to content

Commit a092bd3

Browse files
committed
Address Feedback
1 parent 4612f64 commit a092bd3

File tree

2 files changed

+73
-12
lines changed

2 files changed

+73
-12
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
8181
calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
8282
const SmallVector<OpFoldResult> &originalOffsets,
8383
const SmallVector<Value> &localOffset,
84-
const SmallVector<int64_t> &distUnitBaseAddr) const {
84+
const SmallVector<int64_t> &distUnitBaseAddr,
85+
const SmallVector<int64_t> &distUnitShape) const {
8586
assert(localOffset.size() == distUnitBaseAddr.size() &&
8687
"localOffset and distUnitBaseAddr must have the same rank");
8788

@@ -105,9 +106,13 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
105106
rewriter.create<arith::ConstantIndexOp>(loc, distUnitBaseAddr[i]);
106107
Value offset =
107108
rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
109+
Value modValue =
110+
rewriter.create<arith::ConstantIndexOp>(loc, distUnitShape[i]);
111+
Value offsetMod =
112+
rewriter.createOrFold<index::RemUOp>(loc, offset, modValue);
108113
Value origOffset = getValueFromOpFoldResult(originalOffsets[dimIdx]);
109114
Value globalOffset =
110-
rewriter.createOrFold<index::AddOp>(loc, origOffset, offset);
115+
rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod);
111116
globalOffsets[dimIdx] = globalOffset;
112117
}
113118

@@ -125,10 +130,27 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
125130
return failure();
126131
Type elemTy = tdescTy.getElementType();
127132
ArrayRef<int64_t> wgShape = tdescTy.getShape();
128-
SmallVector<int64_t> sgShape =
129-
llvm::to_vector_of<int64_t>(layout.getSgData().asArrayRef());
130-
SmallVector<int64_t> sgLayout =
131-
llvm::to_vector_of<int64_t>(layout.getSgLayout().asArrayRef());
133+
SmallVector<int64_t> sgLayout;
134+
if (auto sgLayoutAttr = layout.getSgLayout()) {
135+
sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
136+
} else {
137+
// sgLayout must be present for workgroup-level distribution.
138+
op.emitError("sgLayout attribute is required in layout");
139+
return failure();
140+
}
141+
142+
SmallVector<int64_t> sgShape;
143+
if (auto sgDataAttr = layout.getSgData()) {
144+
sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
145+
} else {
146+
assert(wgShape.size() == sgLayout.size() &&
147+
"sgLayout and wgShape must have the same rank");
148+
sgShape.reserve(wgShape.size());
149+
for (size_t i = 0; i < wgShape.size(); ++i) {
150+
assert(sgLayout[i] != 0 && "sgLayout elements must be non-zero");
151+
sgShape.push_back(wgShape[i] / sgLayout[i]);
152+
}
153+
}
132154

133155
// TODO : Handle order attribute
134156
// Get the subgroup ID
@@ -168,8 +190,9 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
168190
SmallVector<Value> newCreateNdOps;
169191
for (SmallVector<int64_t> distUnitBaseAddr :
170192
StaticTileOffsetRange(wgShape, distUnitShape)) {
171-
SmallVector<OpFoldResult> globalOffsets = calculateGlobalOffsets(
172-
rewriter, loc, originalOffsets, localOffset, distUnitBaseAddr);
193+
SmallVector<OpFoldResult> globalOffsets =
194+
calculateGlobalOffsets(rewriter, loc, originalOffsets, localOffset,
195+
distUnitBaseAddr, distUnitShape);
173196

174197
auto newCreateNdOp = rewriter.create<xegpu::CreateNdDescOp>(
175198
loc, newTdescTy, op.getSource(), globalOffsets, op.getMixedSizes(),
@@ -258,11 +281,10 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
258281
if (!originalLayout)
259282
return failure();
260283

261-
size_t i = 0;
262284
SmallVector<Value> newDpasOps;
285+
size_t i = 0;
263286
for (auto aVec : adaptor.getLhs()) {
264287
for (auto bVec : adaptor.getRhs()) {
265-
266288
llvm::SmallVector<Value> operands({aVec, bVec});
267289
Value tmpC;
268290
if (op.getAcc()) {

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,14 @@ gpu.module @test_1_1_assignment {
1414
// CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]]
1515
// CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C12]]
1616
// CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C8]]
17+
// CHECK: %[[C24:.*]] = arith.constant 24 : index
18+
// CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C24]]
1719
// CHECK: %[[C0:.*]] = arith.constant 0 : index
18-
// CHECK: %[[ADD1:.*]] = index.add %[[MUL1]], %[[C0]]
19-
// CHECK: %[[ADD2:.*]] = index.add %[[MUL2]], %[[C0]]
20+
// CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0]]
21+
// CHECK: %[[C32:.*]] = arith.constant 32 : index
22+
// CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C32]]
23+
// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
24+
// CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_1]]
2025
// CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32>
2126
// CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
2227
// CHECK: gpu.return
@@ -108,6 +113,40 @@ gpu.func @test_dpas(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
108113
gpu.return
109114
}
110115

116+
117+
// CHECK-LABEL: test_dpas_no_sg_data
118+
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
119+
// CHECK-SAME: %[[ARG_1:.*]]: memref<32x24xf32>
120+
gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
121+
// CHECK: %[[TDESC_A:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
122+
// CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
123+
// CHECK: %[[LOAD_A:.*]] = xegpu.load_nd %[[TDESC_A]]
124+
// CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
125+
// CHECK-SAME: -> vector<12x8xf32>
126+
// CHECK: %[[TDESC_B:.*]] = xegpu.create_nd_tdesc %[[ARG_1]][{{%.*}}, {{%.*}}] : memref<32x24xf32>
127+
// CHECK-SAME: -> !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
128+
// CHECK: %[[LOAD_B:.*]] = xegpu.load_nd %[[TDESC_B]]
129+
// CHECK-SAME: : !xegpu.tensor_desc<8x12xf32, #xegpu.layout<lane_layout = [8, 2], lane_data = [1, 1]>>
130+
// CHECK-SAME: -> vector<8x12xf32>
131+
// CHECK: %[[DPAS:.*]] = xegpu.dpas %[[LOAD_A]], %[[LOAD_B]]
132+
// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
133+
// CHECK-SAME: : vector<12x8xf32>, vector<8x12xf32> -> vector<12x12xf32>
134+
%tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<24x32xf32>
135+
-> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>>
136+
%load_a = xegpu.load_nd %tdesc_a
137+
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], lane_layout = [2, 8], lane_data = [1, 1]>>
138+
-> vector<24x32xf32>
139+
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<32x24xf32>
140+
-> !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>>
141+
%load_b = xegpu.load_nd %tdesc_b
142+
: !xegpu.tensor_desc<32x24xf32, #xegpu.layout<sg_layout = [4, 2], lane_layout = [8, 2], lane_data = [1, 1]>>
143+
-> vector<32x24xf32>
144+
%dpas = xegpu.dpas %load_a, %load_b
145+
{layout = #xegpu.layout<sg_layout = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
146+
: vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
147+
gpu.return
148+
}
149+
111150
// CHECK-LABEL: test_prefetch_nd_tdesc
112151
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
113152
gpu.func @test_prefetch_nd_tdesc(%src: memref<24x32xf32>) {

0 commit comments

Comments
 (0)