@@ -183,11 +183,15 @@ static void adjustStridesForPermutation(AffineMap permMap,
 // Computes memory strides and a memref offset for vector transfer operations,
 // handling both static and dynamic memrefs while applying permutation
 // transformations for XeGPU lowering.
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
+        vector::GatherOp, vector::ScatterOp>::value>>
 static std::pair<SmallVector<Value>, Value>
-computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
+computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
   SmallVector<Value> strides;
   Value baseMemref = xferOp.getBase();
-  AffineMap permMap = xferOp.getPermutationMap();
   MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
 
   Location loc = xferOp.getLoc();
@@ -197,9 +201,14 @@ computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
     SmallVector<int64_t> intStrides;
     if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
       return {{}, offsetVal};
-    // Wrap static strides as MLIR values
-    for (int64_t s : intStrides)
-      strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
+    bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) {
+      return ShapedType::isDynamic(strideVal);
+    });
+
+    if (!hasDynamicStrides)
+      for (int64_t s : intStrides)
+        strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
+
     if (!ShapedType::isDynamic(offset))
       offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
   }
@@ -232,8 +241,14 @@ computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
     if (!offsetVal)
       offsetVal = meta.getOffset();
   }
-  // Adjust strides according to the permutation map (e.g., for transpose)
-  adjustStridesForPermutation(permMap, strides);
+
+  if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
+                                vector::TransferWriteOp>::value) {
+    AffineMap permMap = xferOp.getPermutationMap();
+    // Adjust strides according to the permutation map (e.g., for transpose)
+    adjustStridesForPermutation(permMap, strides);
+  }
+
   return {strides, offsetVal};
 }
 
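(Aside, not from the patch: in the static path above, an identity-layout memref has the usual row-major suffix-product strides, s_i = d_{i+1} * d_{i+2} * ... * d_{r-1} with offset 0; for example, memref<8x16xf32> gives strides [16, 1], so the loop materializes one constant index per dimension. If any stride is dynamic, the loop is skipped and the strides are left to the dynamic handling later in the function, whose tail is visible in the next hunk via meta.getOffset().)
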
@@ -339,9 +354,51 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
   return localOffsets;
 }
 
+// Compute the element-wise offsets for vector.gather or vector.scatter ops.
+//
+// This function linearizes the base offsets of the gather/scatter operation
+// and combines them with the per-element indices to produce a final vector of
+// memory offsets.
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
+static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
+                            ArrayRef<Value> strides, Value baseOffset) {
+  Location loc = gatScatOp.getLoc();
+  SmallVector<Value> offsets = gatScatOp.getOffsets();
+  for (size_t i = 0; i < offsets.size(); ++i) {
+    Value offsetContrib =
+        arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
+    baseOffset =
+        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
+  }
+  Value indices = gatScatOp.getIndices();
+  VectorType vecType = cast<VectorType>(indices.getType());
+
+  Value strideVector =
+      vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
+          .getResult();
+  Value stridedIndices =
+      arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();
+
+  Value baseVector =
+      vector::BroadcastOp::create(
+          rewriter, loc,
+          VectorType::get(vecType.getShape(), rewriter.getIndexType()),
+          baseOffset)
+          .getResult();
+  return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
+      .getResult();
+}
+
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
+        vector::GatherOp, vector::ScatterOp>::value>>
 // Convert memref to i64 base pointer
-static Value memrefToIndexPtr(VectorTransferOpInterface xferOp,
-                              PatternRewriter &rewriter) {
+static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
   Location loc = xferOp.getLoc();
   auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
                       rewriter, loc, xferOp.getBase())
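As a worked example of the linearization in computeOffsets above (illustrative values, not taken from a test): for a memref<4x8xf32> with strides [8, 1], gather/scatter base offsets (2, 0), a memref base offset of 0, and an indices vector [0, 3, 5, 7], the scalar accumulation yields 0 + 2*8 + 0*1 = 16, and the per-lane offsets are off[k] = 16 + indices[k] * 1 = [16, 19, 21, 23]. In general, off[k] = baseOffset + sum_i offsets[i] * strides[i] + indices[k] * strides.back(), and this flat offset vector is what gets paired with the pointer returned by memrefToIndexPtr.
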
@@ -539,6 +596,71 @@ struct TransferWriteLowering
   }
 };
 
+struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
+  using OpRewritePattern<vector::GatherOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
+    if (!srcTy)
+      return rewriter.notifyMatchFailure(gatherOp, "Expects memref source");
+
+    Location loc = gatherOp.getLoc();
+    VectorType vectorType = gatherOp.getVectorType();
+
+    auto meta = computeMemrefMeta(gatherOp, rewriter);
+    if (meta.first.empty())
+      return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");
+
+    Value localOffsets =
+        computeOffsets(rewriter, gatherOp, meta.first, meta.second);
+    Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);
+
+    auto xeGatherOp = xegpu::LoadGatherOp::create(
+        rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
+        /*chunk_size=*/IntegerAttr{},
+        /*l1_hint=*/xegpu::CachePolicyAttr{},
+        /*l2_hint=*/xegpu::CachePolicyAttr{},
+        /*l3_hint=*/xegpu::CachePolicyAttr{});
+
+    auto selectOp =
+        arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
+                                xeGatherOp.getResult(), gatherOp.getPassThru());
+    rewriter.replaceOp(gatherOp, selectOp.getResult());
+    return success();
+  }
+};
+
+struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
+  using OpRewritePattern<vector::ScatterOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
+    if (!srcTy)
+      return rewriter.notifyMatchFailure(scatterOp, "Expects memref source");
+
+    Location loc = scatterOp.getLoc();
+    auto meta = computeMemrefMeta(scatterOp, rewriter);
+    if (meta.first.empty())
+      return rewriter.notifyMatchFailure(scatterOp,
+                                         "Failed to compute strides");
+
+    Value localOffsets =
+        computeOffsets(rewriter, scatterOp, meta.first, meta.second);
+    Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);
+
+    xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
+                                  flatMemref, localOffsets, scatterOp.getMask(),
+                                  /*chunk_size=*/IntegerAttr{},
+                                  /*l1_hint=*/xegpu::CachePolicyAttr{},
+                                  /*l2_hint=*/xegpu::CachePolicyAttr{},
+                                  /*l3_hint=*/xegpu::CachePolicyAttr{});
+    rewriter.eraseOp(scatterOp);
+    return success();
+  }
+};
+
 struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
 
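A note on the two patterns above: the trailing arith.select on the gather path reproduces vector.gather's pass-through semantics, where lanes with a false mask bit must take their value from pass_thru rather than from memory. For example, with mask [1, 0, 1, 0], loaded lanes [l0, l1, l2, l3], and pass_thru [p0, p1, p2, p3], the select produces [l0, p1, l2, p3]. The scatter path needs no such fixup, since masked-off lanes simply do not store.
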
@@ -654,6 +776,8 @@ struct ConvertVectorToXeGPUPass
 
 void mlir::populateVectorToXeGPUConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
-               StoreLowering, ContractionLowering>(patterns.getContext());
+  patterns
+      .add<TransferReadLowering, TransferWriteLowering, LoadLowering,
+           ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
+          patterns.getContext());
 }
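For reference, a minimal sketch of driving the extended pattern set (hypothetical caller, not part of this patch; it assumes the greedy pattern driver, that the relevant dialects are already loaded on the context, and a tree recent enough to have applyPatternsGreedily rather than the older applyPatternsAndFoldGreedily):

  // Hypothetical runOnOperation body for a pass that uses these patterns.
  RewritePatternSet patterns(&getContext());
  mlir::populateVectorToXeGPUConversionPatterns(patterns);
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    signalPassFailure();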