Skip to content

Commit 398d69b

Browse files
committed
cleanup
1 parent 3630966 commit 398d69b

File tree

2 files changed

+8
-52
lines changed

2 files changed

+8
-52
lines changed

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
222222
if (!isWgLayout())
223223
return failure();
224224

225+
// TODO: handle order attribute
225226
auto dims =
226227
llvm::map_to_vector(*getEffectiveSgLayout(), [&](int64_t d) -> Value {
227228
return builder.createOrFold<arith::ConstantIndexOp>(loc, d);

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 7 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -125,39 +125,6 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
125125
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
126126
using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
127127

128-
// Calculate offset for each subgroup
129-
static SmallVector<OpFoldResult>
130-
calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
131-
const SmallVector<OpFoldResult> &originalOffsets,
132-
const SmallVector<Value> &localOffset,
133-
const SmallVector<int64_t> &distUnitBaseAddr,
134-
const SmallVector<int64_t> &distUnitShape) {
135-
assert(localOffset.size() == distUnitBaseAddr.size() &&
136-
"localOffset and distUnitBaseAddr must have the same rank");
137-
138-
SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
139-
originalOffsets.end());
140-
size_t rank = localOffset.size();
141-
for (size_t i = 0; i < rank; ++i) {
142-
size_t dimIdx = originalOffsets.size() - rank + i;
143-
Value constOffset =
144-
arith::ConstantIndexOp::create(rewriter, loc, distUnitBaseAddr[i]);
145-
Value offset =
146-
rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
147-
Value modValue =
148-
arith::ConstantIndexOp::create(rewriter, loc, distUnitShape[i]);
149-
Value offsetMod =
150-
rewriter.createOrFold<index::RemUOp>(loc, offset, modValue);
151-
Value origOffset = getValueOrCreateConstantIndexOp(
152-
rewriter, loc, originalOffsets[dimIdx]);
153-
Value globalOffset =
154-
rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod);
155-
globalOffsets[dimIdx] = globalOffset;
156-
}
157-
158-
return globalOffsets;
159-
}
160-
161128
LogicalResult
162129
matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
163130
ConversionPatternRewriter &rewriter) const override {
@@ -177,28 +144,14 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
177144
return rewriter.notifyMatchFailure(
178145
op, "sgLayout attribute is required in layout");
179146

180-
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
181-
182-
// TODO : Handle order attribute
183147
// Get the subgroup ID
184-
auto linearSgId =
148+
Value linearSgId =
185149
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
186150

187-
// Create constants for layout dimensions
188-
SmallVector<Value> sgLayoutDim(sgLayout.size());
189-
SmallVector<Value> sgDataDim(sgShape.size());
190-
191-
for (size_t i = 0; i < sgLayout.size(); i++) {
192-
sgLayoutDim[i] =
193-
arith::ConstantIndexOp::create(rewriter, loc, sgLayout[i]);
194-
sgDataDim[i] = arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
195-
}
196-
197151
int64_t startOfRange = -1, endOfRange = -1;
198152
bool sgIdRangeSpecified =
199153
isSgIdRangeSpecified(op, startOfRange, endOfRange);
200154

201-
Value adjustedSgId = linearSgId;
202155
if (sgIdRangeSpecified) {
203156
int64_t sgCount = endOfRange - startOfRange;
204157
if (computeProduct(sgLayout) != sgCount)
@@ -208,22 +161,24 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
208161
// sg id
209162
Value startOfRangeVal =
210163
rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
211-
adjustedSgId =
164+
linearSgId =
212165
rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
213166
}
214167

215-
auto tdescOffsets = layout.getOffsets(rewriter, loc, adjustedSgId, wgShape);
216-
if (failed(tdescOffsets))
168+
auto maybeTdescOffsets =
169+
layout.getOffsets(rewriter, loc, linearSgId, wgShape);
170+
if (failed(maybeTdescOffsets))
217171
return failure();
218172

173+
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
219174
xegpu::TensorDescType newTdescTy =
220175
xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
221176
layout.dropSgLayoutAndData());
222177

223178
SmallVector<Value> newCreateNdOps;
224179
SmallVector<OpFoldResult> offset = op.getMixedOffsets();
225180

226-
for (auto tdescOffset : *tdescOffsets) {
181+
for (auto tdescOffset : *maybeTdescOffsets) {
227182
SmallVector<OpFoldResult> newOffsets = llvm::map_to_vector(
228183
llvm::zip_longest(tdescOffset, offset),
229184
[&](const auto &t) -> OpFoldResult {

0 commit comments

Comments
 (0)