Skip to content

Commit b8da87e

Browse files
committed
Use getMixedOffsets
1 parent 46686f5 commit b8da87e

File tree

1 file changed

+20
-23
lines changed

1 file changed

+20
-23
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -70,23 +70,6 @@ namespace {
7070
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
7171
using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
7272

73-
// Helper to extract mixed offsets into a Value array
74-
SmallVector<Value> extractOffsets(ConversionPatternRewriter &rewriter,
75-
xegpu::CreateNdDescOp op) const {
76-
llvm::SmallVector<Value> offsets;
77-
auto staticOffsets = op.getStaticOffsets();
78-
auto dynamicOffsets = op.getOffsets();
79-
80-
for (size_t i = 0, j = 0; i != staticOffsets.size(); i++) {
81-
if (ShapedType::isDynamic(staticOffsets[i]))
82-
offsets.push_back(dynamicOffsets[j++]);
83-
else
84-
offsets.push_back(rewriter.create<arith::ConstantIndexOp>(
85-
op.getLoc(), staticOffsets[i]));
86-
}
87-
return offsets;
88-
}
89-
9073
// Convert linear subgroup ID to 2D coordinates
9174
// TODO: Delinearize for nD
9275
SmallVector<Value> delinearizeSubgroupId(ConversionPatternRewriter &rewriter,
@@ -99,7 +82,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
9982
// Calculate offset for each subgroup
10083
SmallVector<OpFoldResult>
10184
calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
102-
const SmallVector<Value> &originalOffsets,
85+
const SmallVector<OpFoldResult> &originalOffsets,
10386
const SmallVector<Value> &localOffset,
10487
const SmallVector<int64_t> &distUnitBaseAddr) const {
10588

@@ -116,10 +99,24 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
11699
size_t lastDimIndex = originalOffsets.size() - 1;
117100
size_t secondLastDimIndex = lastDimIndex - 1;
118101

119-
Value globalOffsetX = rewriter.createOrFold<index::AddOp>(
120-
loc, originalOffsets[secondLastDimIndex], offsetX);
121-
Value globalOffsetY = rewriter.createOrFold<index::AddOp>(
122-
loc, originalOffsets[lastDimIndex], offsetY);
102+
// Convert originalOffsets to Value
103+
auto getValueFromOpFoldResult = [&](OpFoldResult ofr) -> Value {
104+
if (auto val = ofr.dyn_cast<Value>())
105+
return val;
106+
if (auto attr = ofr.dyn_cast<Attribute>()) {
107+
int64_t staticOffset = cast<IntegerAttr>(attr).getInt();
108+
return rewriter.create<arith::ConstantIndexOp>(loc, staticOffset);
109+
}
110+
llvm_unreachable("Unsupported OpFoldResult kind");
111+
};
112+
113+
Value origOffsetX =
114+
getValueFromOpFoldResult(originalOffsets[secondLastDimIndex]);
115+
Value origOffsetY = getValueFromOpFoldResult(originalOffsets[lastDimIndex]);
116+
Value globalOffsetX =
117+
rewriter.createOrFold<index::AddOp>(loc, origOffsetX, offsetX);
118+
Value globalOffsetY =
119+
rewriter.createOrFold<index::AddOp>(loc, origOffsetY, offsetY);
123120

124121
SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
125122
originalOffsets.end());
@@ -172,7 +169,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
172169
rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
173170
}
174171

175-
SmallVector<Value> originalOffsets = extractOffsets(rewriter, op);
172+
SmallVector<OpFoldResult> originalOffsets = op.getMixedOffsets();
176173

177174
xegpu::TensorDescType newTdescTy =
178175
xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),

0 commit comments

Comments
 (0)