Commit c4617bc

[MLIR][XeGPU][VectorToXeGPU] Add lowering from vector.gather/scatter to xegpu.load/store (llvm#158024)
Lowering for `vector.gather`/`vector.scatter` into `xegpu.load`/`xegpu.store`.

High-level steps to lower vector.gather/scatter:

```
%0 = vector.gather %source[%off1, %off2, %off3][%indices], %mask, %pass_thru
    : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
```

1. Compute strides and a memref offset for the `%source` memref using the `computeMemrefMeta` func from the transfer_read/write lowering.
2. Compute a linear offset like `%lin_off = %base_offset + %off1 * strides#0 + %off2 * strides#1 + %off3 * strides#2`.
3. Combine the linear offset with `%indices`: `%off = (broadcast %lin_off : index to vector<8xindex>) + %indices * strides#2`.
4. Convert the memref to an i64 pointer: `%flat_memref = memref.extract_aligned_pointer_as_index %source` + `arith.index_cast`.
5. Perform the load/store: `%vec = xegpu.load %flat_memref[%off], %mask`.
6. Apply selection to propagate values from the pass_thru vector: `%res = arith.select %mask, %vec, %pass_thru`.
1 parent 771c94c commit c4617bc
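To make the offset arithmetic in steps 2 and 3 concrete, here is a minimal scalar sketch (illustration only, not part of this commit; the helper name and its parameters are hypothetical). For the `memref<8x16x32xf32>` example above, the contiguous strides are {512, 32, 1}, and each gather lane ends up at `lin_off + index * strides.back()`:

```cpp
#include <array>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring steps 2-3 of the lowering: linearize the
// leading offsets, then add the per-lane indices scaled by the innermost
// stride to get one element offset per lane.
std::vector<int64_t> computeGatherOffsets(std::array<int64_t, 3> strides,
                                          int64_t baseOffset,
                                          std::array<int64_t, 3> offs,
                                          const std::vector<int64_t> &indices) {
  // Step 2: %lin_off = %base_offset + %off1*strides#0 + %off2*strides#1 + %off3*strides#2
  int64_t linOff = baseOffset;
  for (size_t i = 0; i < offs.size(); ++i)
    linOff += offs[i] * strides[i];

  // Step 3: broadcast %lin_off and add %indices scaled by the last stride.
  std::vector<int64_t> result;
  result.reserve(indices.size());
  for (int64_t idx : indices)
    result.push_back(linOff + idx * strides.back());
  return result;
}
```

In the actual lowering these scalar operations are emitted as `arith.muli`/`arith.addi` plus `vector.broadcast` ops by the new `computeOffsets` overload shown in the diff below.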

File tree

3 files changed: +592 −11 lines changed

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 135 additions & 11 deletions
```diff
@@ -183,11 +183,15 @@ static void adjustStridesForPermutation(AffineMap permMap,
 // Computes memory strides and a memref offset for vector transfer operations,
 // handling both static and dynamic memrefs while applying permutation
 // transformations for XeGPU lowering.
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
+        vector::GatherOp, vector::ScatterOp>::value>>
 static std::pair<SmallVector<Value>, Value>
-computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
+computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
   SmallVector<Value> strides;
   Value baseMemref = xferOp.getBase();
-  AffineMap permMap = xferOp.getPermutationMap();
   MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
 
   Location loc = xferOp.getLoc();
@@ -197,9 +201,14 @@ computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
     SmallVector<int64_t> intStrides;
     if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
       return {{}, offsetVal};
-    // Wrap static strides as MLIR values
-    for (int64_t s : intStrides)
-      strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
+    bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) {
+      return ShapedType::isDynamic(strideVal);
+    });
+
+    if (!hasDynamicStrides)
+      for (int64_t s : intStrides)
+        strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
+
     if (!ShapedType::isDynamic(offset))
       offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
   }
@@ -232,8 +241,14 @@ computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
     if (!offsetVal)
       offsetVal = meta.getOffset();
   }
-  // Adjust strides according to the permutation map (e.g., for transpose)
-  adjustStridesForPermutation(permMap, strides);
+
+  if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
+                                vector::TransferWriteOp>::value) {
+    AffineMap permMap = xferOp.getPermutationMap();
+    // Adjust strides according to the permutation map (e.g., for transpose)
+    adjustStridesForPermutation(permMap, strides);
+  }
+
   return {strides, offsetVal};
 }
 
@@ -339,9 +354,51 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
   return localOffsets;
 }
 
+// Compute the element-wise offsets for vector.gather or vector.scatter ops.
+//
+// This function linearizes the base offsets of the gather/scatter operation
+// and combines them with the per-element indices to produce a final vector of
+// memory offsets.
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
+static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
+                            ArrayRef<Value> strides, Value baseOffset) {
+  Location loc = gatScatOp.getLoc();
+  SmallVector<Value> offsets = gatScatOp.getOffsets();
+  for (size_t i = 0; i < offsets.size(); ++i) {
+    Value offsetContrib =
+        arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
+    baseOffset =
+        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
+  }
+  Value indices = gatScatOp.getIndices();
+  VectorType vecType = cast<VectorType>(indices.getType());
+
+  Value strideVector =
+      vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
+          .getResult();
+  Value stridedIndices =
+      arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();
+
+  Value baseVector =
+      vector::BroadcastOp::create(
+          rewriter, loc,
+          VectorType::get(vecType.getShape(), rewriter.getIndexType()),
+          baseOffset)
+          .getResult();
+  return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
+      .getResult();
+}
+
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
+        vector::GatherOp, vector::ScatterOp>::value>>
 // Convert memref to i64 base pointer
-static Value memrefToIndexPtr(VectorTransferOpInterface xferOp,
-                              PatternRewriter &rewriter) {
+static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
   Location loc = xferOp.getLoc();
   auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
                       rewriter, loc, xferOp.getBase())
@@ -539,6 +596,71 @@ struct TransferWriteLowering
   }
 };
 
+struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
+  using OpRewritePattern<vector::GatherOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
+    if (!srcTy)
+      return rewriter.notifyMatchFailure(gatherOp, "Expects memref source");
+
+    Location loc = gatherOp.getLoc();
+    VectorType vectorType = gatherOp.getVectorType();
+
+    auto meta = computeMemrefMeta(gatherOp, rewriter);
+    if (meta.first.empty())
+      return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");
+
+    Value localOffsets =
+        computeOffsets(rewriter, gatherOp, meta.first, meta.second);
+    Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);
+
+    auto xeGatherOp = xegpu::LoadGatherOp::create(
+        rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
+        /*chunk_size=*/IntegerAttr{},
+        /*l1_hint=*/xegpu::CachePolicyAttr{},
+        /*l2_hint=*/xegpu::CachePolicyAttr{},
+        /*l3_hint=*/xegpu::CachePolicyAttr{});
+
+    auto selectOp =
+        arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
+                                xeGatherOp.getResult(), gatherOp.getPassThru());
+    rewriter.replaceOp(gatherOp, selectOp.getResult());
+    return success();
+  }
+};
+
+struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
+  using OpRewritePattern<vector::ScatterOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
+    if (!srcTy)
+      return rewriter.notifyMatchFailure(scatterOp, "Expects memref source");
+
+    Location loc = scatterOp.getLoc();
+    auto meta = computeMemrefMeta(scatterOp, rewriter);
+    if (meta.first.empty())
+      return rewriter.notifyMatchFailure(scatterOp,
+                                         "Failed to compute strides");
+
+    Value localOffsets =
+        computeOffsets(rewriter, scatterOp, meta.first, meta.second);
+    Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);
+
+    xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
+                                  flatMemref, localOffsets, scatterOp.getMask(),
+                                  /*chunk_size=*/IntegerAttr{},
+                                  /*l1_hint=*/xegpu::CachePolicyAttr{},
+                                  /*l2_hint=*/xegpu::CachePolicyAttr{},
+                                  /*l3_hint=*/xegpu::CachePolicyAttr{});
+    rewriter.eraseOp(scatterOp);
+    return success();
+  }
+};
+
 struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
 
@@ -654,6 +776,8 @@ struct ConvertVectorToXeGPUPass
 
 void mlir::populateVectorToXeGPUConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
-               StoreLowering, ContractionLowering>(patterns.getContext());
+  patterns
+      .add<TransferReadLowering, TransferWriteLowering, LoadLowering,
+           ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
+          patterns.getContext());
 }
```
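For reference, here is a minimal sketch of how these patterns could be exercised outside the existing conversion pass; this driver function is an assumption for illustration (not part of the commit) and simply feeds `populateVectorToXeGPUConversionPatterns` into MLIR's greedy rewrite driver:

```cpp
#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical driver: collect the VectorToXeGPU patterns (now including
// GatherLowering and ScatterLowering) and apply them greedily to a module.
static LogicalResult lowerVectorToXeGPU(ModuleOp module) {
  RewritePatternSet patterns(module.getContext());
  populateVectorToXeGPUConversionPatterns(patterns);
  return applyPatternsGreedily(module, std::move(patterns));
}
```

In tree, the same wiring happens inside `ConvertVectorToXeGPUPass`, which now picks up `GatherLowering` and `ScatterLowering` through the updated `populateVectorToXeGPUConversionPatterns` shown in the last hunk.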
