@@ -248,27 +248,26 @@ class UpdateNdOffsetToXeVMPattern
248248 xegpu::UpdateNdOffsetOp::Adaptor adaptor,
249249 ConversionPatternRewriter &rewriter) const override {
250250 auto loc = op.getLoc ();
251- auto offsets = op.getOffsets ();
251+ auto mixedOffsets = op.getMixedOffsets ();
252+ if (mixedOffsets.size () != 2 )
253+ return rewriter.notifyMatchFailure (op, " Expected 2D offsets." );
252254 auto tdesc = adaptor.getTensorDesc ();
253- for (size_t offsetDim = 0 ; offsetDim < offsets.size (); offsetDim++) {
254- auto offset = offsets[offsetDim];
255- if (auto cst =
256- dyn_cast_if_present<arith::ConstantOp>(offset.getDefiningOp ()))
257- if (auto attr = dyn_cast_if_present<IntegerAttr>(cst.getValue ());
258- attr && !attr.getInt ())
259- continue ;
260- const int offsetPos =
261- static_cast <int >(offsetDim ? NdDescI32Layout::TensorOffsetW
262- : NdDescI32Layout::TensorOffsetH);
263- auto oldOffset =
264- vector::ExtractOp::create (rewriter, loc, tdesc, offsetPos);
265- offset = arith::IndexCastUIOp::create (rewriter, loc,
266- rewriter.getI32Type (), offset);
267- auto newOffset = arith::AddIOp::create (rewriter, loc, oldOffset, offset);
268- tdesc =
269- vector::InsertOp::create (rewriter, loc, newOffset, tdesc, offsetPos);
270- }
271- rewriter.replaceOp (op, tdesc);
255+ // utility for updating payload offset values from op fold result.
256+ auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
257+ Value offset =
258+ getValueOrCreateConstantIntOp (rewriter, loc, mixedOffsets[idx]);
259+ offset = getValueOrCreateCastToIndexLike (rewriter, loc,
260+ rewriter.getI32Type (), offset);
261+ Value oldOffset =
262+ vector::ExtractOp::create (rewriter, loc, tdesc, payloadPos);
263+ Value newOffset = arith::AddIOp::create (rewriter, loc, oldOffset, offset);
264+ return vector::InsertOp::create (rewriter, loc, newOffset, tdesc,
265+ payloadPos);
266+ };
267+ auto val =
268+ updateOffset (0 , static_cast <int >(NdDescI32Layout::TensorOffsetH));
269+ val = updateOffset (1 , static_cast <int >(NdDescI32Layout::TensorOffsetW));
270+ rewriter.replaceOp (op, val);
272271 return success ();
273272 }
274273};
0 commit comments