Skip to content

Commit d88d676

Browse files
committed
Update update_nd_tdesc.
1 parent d372592 commit d88d676

File tree

1 file changed

+19
-20
lines changed

1 file changed

+19
-20
lines changed

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -248,27 +248,26 @@ class UpdateNdOffsetToXeVMPattern
248248
xegpu::UpdateNdOffsetOp::Adaptor adaptor,
249249
ConversionPatternRewriter &rewriter) const override {
250250
auto loc = op.getLoc();
251-
auto offsets = op.getOffsets();
251+
auto mixedOffsets = op.getMixedOffsets();
252+
if (mixedOffsets.size() != 2)
253+
return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
252254
auto tdesc = adaptor.getTensorDesc();
253-
for (size_t offsetDim = 0; offsetDim < offsets.size(); offsetDim++) {
254-
auto offset = offsets[offsetDim];
255-
if (auto cst =
256-
dyn_cast_if_present<arith::ConstantOp>(offset.getDefiningOp()))
257-
if (auto attr = dyn_cast_if_present<IntegerAttr>(cst.getValue());
258-
attr && !attr.getInt())
259-
continue;
260-
const int offsetPos =
261-
static_cast<int>(offsetDim ? NdDescI32Layout::TensorOffsetW
262-
: NdDescI32Layout::TensorOffsetH);
263-
auto oldOffset =
264-
vector::ExtractOp::create(rewriter, loc, tdesc, offsetPos);
265-
offset = arith::IndexCastUIOp::create(rewriter, loc,
266-
rewriter.getI32Type(), offset);
267-
auto newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
268-
tdesc =
269-
vector::InsertOp::create(rewriter, loc, newOffset, tdesc, offsetPos);
270-
}
271-
rewriter.replaceOp(op, tdesc);
255+
// utility for updating payload offset values from op fold result.
256+
auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
257+
Value offset =
258+
getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]);
259+
offset = getValueOrCreateCastToIndexLike(rewriter, loc,
260+
rewriter.getI32Type(), offset);
261+
Value oldOffset =
262+
vector::ExtractOp::create(rewriter, loc, tdesc, payloadPos);
263+
Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
264+
return vector::InsertOp::create(rewriter, loc, newOffset, tdesc,
265+
payloadPos);
266+
};
267+
auto val =
268+
updateOffset(0, static_cast<int>(NdDescI32Layout::TensorOffsetH));
269+
val = updateOffset(1, static_cast<int>(NdDescI32Layout::TensorOffsetW));
270+
rewriter.replaceOp(op, val);
272271
return success();
273272
}
274273
};

0 commit comments

Comments
 (0)