Skip to content

Commit 08e4aa9

Browse files
committed
fix a bug
1 parent 398d69b commit 08e4aa9

File tree

1 file changed

+12
-18
lines changed

1 file changed

+12
-18
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -179,26 +179,20 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
179179
SmallVector<OpFoldResult> offset = op.getMixedOffsets();
180180

181181
for (auto tdescOffset : *maybeTdescOffsets) {
182-
SmallVector<OpFoldResult> newOffsets = llvm::map_to_vector(
183-
llvm::zip_longest(tdescOffset, offset),
184-
[&](const auto &t) -> OpFoldResult {
185-
std::optional<Value> off = std::get<0>(t);
186-
std::optional<OpFoldResult> old = std::get<1>(t);
187-
if (!off.has_value())
188-
return *old;
189-
190-
if (!old.has_value() || isZeroInteger(*old))
191-
return *off;
192-
193-
return rewriter.createOrFold<index::AddOp>(
194-
loc, *off,
195-
getValueOrCreateConstantIndexOp(rewriter, loc, *old));
196-
});
197-
198-
auto newCreateNdOp = xegpu::CreateNdDescOp::create(
182+
SmallVector<OpFoldResult> newOffsets;
183+
size_t rank = tdescOffset.size();
184+
for (size_t i = 0; i < rank; i++) {
185+
size_t idx = offset.size() - rank + i;
186+
Value newOff = rewriter.createOrFold<index::AddOp>(
187+
loc, tdescOffset[i],
188+
getValueOrCreateConstantIndexOp(rewriter, loc, offset[idx]));
189+
newOffsets.push_back(newOff);
190+
}
191+
192+
auto newOp = xegpu::CreateNdDescOp::create(
199193
rewriter, loc, newTdescTy, op.getSource(), newOffsets,
200194
op.getMixedSizes(), op.getMixedStrides());
201-
newCreateNdOps.push_back(newCreateNdOp);
195+
newCreateNdOps.push_back(newOp);
202196
}
203197
rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
204198
return success();

0 commit comments

Comments
 (0)