Commit d97a3c2

[MLIR][XeGPU] Add lowering from vector.gather/scatter to xegpu.load/store
Signed-off-by: dchigarev <[email protected]>
1 parent 1f49c94 commit d97a3c2

File tree

3 files changed (+420, -8 lines)
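
At a glance, the new patterns rewrite a masked vector.gather (and, symmetrically, vector.scatter) on a contiguous memref into a linearized offset computation, a 1D view of the buffer, and a scattered xegpu.load whose result is blended with the pass-through value. A rough before/after sketch, distilled from the new test's CHECK lines (SSA names are illustrative):

  // Before: gather from a 2D memref in the vector dialect.
  %0 = vector.gather %src[%off1, %off2][%indices], %mask, %pass_thru
      : memref<8x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>

  // After -convert-vector-to-xegpu (simplified).
  %c32   = arith.constant 32 : index
  %s0    = arith.muli %off1, %c32 : index
  %base  = arith.addi %s0, %off2 : index
  %splat = vector.broadcast %base : index to vector<8xindex>
  %lin   = arith.addi %splat, %indices : vector<8xindex>
  %flat  = memref.collapse_shape %src [[0, 1]] : memref<8x32xf32> into memref<256xf32>
  %vec   = xegpu.load %flat[%lin], %mask : memref<256xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
  %res   = arith.select %mask, %vec, %pass_thru : vector<8xi1>, vector<8xf32>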

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 135 additions & 8 deletions
@@ -97,6 +97,21 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   return success();
 }
 
+static LogicalResult gatherScatterPreconditions(PatternRewriter &rewriter,
+                                                Operation *op, Type baseType) {
+  auto srcTy = dyn_cast<MemRefType>(baseType);
+  if (!srcTy)
+    return rewriter.notifyMatchFailure(op, "Expects memref source");
+
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
+    return rewriter.notifyMatchFailure(
+        op, "Buffer must be contiguous in the innermost dimension");
+
+  return success();
+}
+
 static xegpu::CreateNdDescOp
 createNdDescriptor(PatternRewriter &rewriter, Location loc,
                    xegpu::TensorDescType descType, TypedValue<MemRefType> src,
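
The precondition only admits memref sources whose innermost dimension is statically contiguous; everything else keeps the original vector op. Illustrative cases, taken from the types exercised by the new test file (a sketch, not an exhaustive list):

  // Accepted: innermost stride is statically 1.
  //   memref<8x16x32xf32>                         // strides [512, 32, 1]
  //   memref<?x8x16xf32>                          // strides [128, 16, 1]
  // Rejected: the pattern bails out and vector.gather/scatter is left as-is.
  //   tensor<32x64xf32>                           // not a memref
  //   memref<32xf32, strided<[?], offset: ?>>     // inner stride not statically 1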
@@ -183,11 +198,15 @@ static void adjustStridesForPermutation(AffineMap permMap,
 // Computes memory strides and a memref offset for vector transfer operations,
 // handling both static and dynamic memrefs while applying permutation
 // transformations for XeGPU lowering.
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
+        vector::GatherOp, vector::ScatterOp>::value>>
 static std::pair<SmallVector<Value>, Value>
-computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
+computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
   SmallVector<Value> strides;
   Value baseMemref = xferOp.getBase();
-  AffineMap permMap = xferOp.getPermutationMap();
   MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
 
   Location loc = xferOp.getLoc();
@@ -232,9 +251,15 @@ computeMemrefMeta(VectorTransferOpInterface xferOp, PatternRewriter &rewriter) {
     if (!offsetVal)
       offsetVal = meta.getOffset();
   }
-  // Adjust strides according to the permutation map (e.g., for transpose)
-  adjustStridesForPermutation(permMap, strides);
-  return {strides, offsetVal};
+
+  if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
+                                vector::TransferWriteOp>::value) {
+    AffineMap permMap = xferOp.getPermutationMap();
+    // Adjust strides according to the permutation map (e.g., for transpose)
+    adjustStridesForPermutation(permMap, strides);
+  }
+
+  return {strides, offsetVal};
 }
 
 // This function computes the vectors of localOffsets for scattered load/stores.
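
Since vector.gather/scatter carry no permutation map, the stride adjustment is now guarded so it only runs for the transfer ops. A hedged example of the case that guard preserves, with illustrative SSA names, assuming a simple transposing read:

  // permutation_map (d0, d1) -> (d1, d0): a transposed read of memref<8x16xf32>.
  // The helper reorders the natural strides [16, 1] to match the transposed
  // result (e.g. to [1, 16]) before offsets are computed.
  %v = vector.transfer_read %src[%i, %j], %pad
      {in_bounds = [true, true],
       permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
      : memref<8x16xf32>, vector<16x8xf32>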
@@ -339,8 +364,45 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
   return localOffsets;
 }
 
+// Compute the element-wise offsets for vector.gather or vector.scatter ops.
+//
+// This function linearizes the base offsets of the gather/scatter operation
+// and combines them with the per-element indices to produce a final vector of
+// memory offsets.
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
+static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
+                            ArrayRef<Value> strides) {
+  Location loc = gatScatOp.getLoc();
+  SmallVector<Value> offsets = gatScatOp.getOffsets();
+  Value linearOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+  for (size_t i = 0; i < offsets.size(); ++i) {
+    Value offsetContrib =
+        arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
+    linearOffset =
+        arith::AddIOp::create(rewriter, loc, linearOffset, offsetContrib);
+  }
+  Value indices = gatScatOp.getIndices();
+  VectorType vecType = cast<VectorType>(indices.getType());
+
+  Value baseVector =
+      vector::BroadcastOp::create(
+          rewriter, loc,
+          VectorType::get(vecType.getShape(), rewriter.getIndexType()),
+          linearOffset)
+          .getResult();
+  return arith::AddIOp::create(rewriter, loc, baseVector, indices).getResult();
+}
+
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
+        vector::GatherOp, vector::ScatterOp>::value>>
 // Convert memref to i64 base pointer
-static Value memrefToIndexPtr(VectorTransferOpInterface xferOp,
+static Value memrefToIndexPtr(OpType xferOp,
                               PatternRewriter &rewriter) {
   Location loc = xferOp.getLoc();
   auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
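
For a statically shaped 8x16x32 source, the strides are [512, 32, 1], so the helper effectively computes base = %off1 * 512 + %off2 * 32 + %off3, broadcasts it, and adds the per-element index vector. A simplified sketch of the emitted IR (illustrative SSA names; the zero seed and the multiply by the unit stride are assumed to fold away, matching the two muli/addi pairs the tests expect):

  %c512  = arith.constant 512 : index
  %c32   = arith.constant 32 : index
  %t0    = arith.muli %off1, %c512 : index
  %t1    = arith.muli %off2, %c32 : index
  %s0    = arith.addi %t0, %t1 : index
  %base  = arith.addi %s0, %off3 : index
  // Broadcast the scalar base offset and add the per-element indices.
  %splat = vector.broadcast %base : index to vector<8xindex>
  %lin   = arith.addi %splat, %indices : vector<8xindex>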
@@ -539,6 +601,69 @@ struct TransferWriteLowering
   }
 };
 
+struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
+  using OpRewritePattern<vector::GatherOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(gatherScatterPreconditions(rewriter, gatherOp,
+                                          gatherOp.getBase().getType())))
+      return failure();
+
+    Location loc = gatherOp.getLoc();
+    VectorType vectorType = gatherOp.getVectorType();
+
+    SmallVector<Value> strides = computeMemrefMeta(gatherOp, rewriter).first;
+    if (strides.empty())
+      return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");
+
+    Value localOffsets = computeOffsets(rewriter, gatherOp, strides);
+    Value flatMemref = collapseMemrefTo1D(gatherOp, rewriter);
+
+    auto xeGatherOp = xegpu::LoadGatherOp::create(
+        rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
+        /*chunk_size=*/IntegerAttr{},
+        /*l1_hint=*/xegpu::CachePolicyAttr{},
+        /*l2_hint=*/xegpu::CachePolicyAttr{},
+        /*l3_hint=*/xegpu::CachePolicyAttr{});
+
+    auto selectOp =
+        arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
+                                xeGatherOp.getResult(), gatherOp.getPassThru());
+    rewriter.replaceOp(gatherOp, selectOp.getResult());
+    return success();
+  }
+};
+
+struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
+  using OpRewritePattern<vector::ScatterOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
+                                PatternRewriter &rewriter) const override {
+    if (failed(gatherScatterPreconditions(rewriter, scatterOp,
+                                          scatterOp.getBase().getType())))
+      return failure();
+
+    Location loc = scatterOp.getLoc();
+    SmallVector<Value> strides = computeMemrefMeta(scatterOp, rewriter).first;
+    if (strides.empty())
+      return rewriter.notifyMatchFailure(scatterOp,
+                                         "Failed to compute strides");
+
+    Value localOffsets = computeOffsets(rewriter, scatterOp, strides);
+    Value flatMemref = collapseMemrefTo1D(scatterOp, rewriter);
+
+    xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
+                                  flatMemref, localOffsets, scatterOp.getMask(),
+                                  /*chunk_size=*/IntegerAttr{},
+                                  /*l1_hint=*/xegpu::CachePolicyAttr{},
+                                  /*l2_hint=*/xegpu::CachePolicyAttr{},
+                                  /*l3_hint=*/xegpu::CachePolicyAttr{});
+    rewriter.eraseOp(scatterOp);
+    return success();
+  }
+};
+
 struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
   using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
 
@@ -654,6 +779,8 @@ struct ConvertVectorToXeGPUPass
 
 void mlir::populateVectorToXeGPUConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
-               StoreLowering, ContractionLowering>(patterns.getContext());
+  patterns
+      .add<TransferReadLowering, TransferWriteLowering, LoadLowering,
+           ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
+          patterns.getContext());
 }
Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s

gpu.module @xevm_module {
gpu.func @load_1D_vector(%source: memref<8x16x32xf32>,
    %off1: index, %off2: index, %off3: index,
    %indices: vector<8xindex>, %mask: vector<8xi1>,
    %pass_thru: vector<8xf32>) -> vector<8xf32> {
  %0 = vector.gather %source[%off1, %off2, %off3][%indices], %mask,
    %pass_thru : memref<8x16x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
  gpu.return %0 : vector<8xf32>
}
// CHECK-LABEL: @load_1D_vector(
// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
// CHECK-SAME:  %[[INDICES:.+]]: vector<8xindex>
// CHECK-SAME:  %[[MASK:.+]]: vector<8xi1>
// CHECK-SAME:  %[[PASS_THRU:.+]]: vector<8xf32>) -> vector<8xf32> {
// CHECK-COUNT-2: arith.muli {{.*}} : index
// CHECK-COUNT-2: arith.addi {{.*}} : index
// CHECK:       %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
// CHECK:       %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
// CHECK:       %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
// CHECK:       %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
// CHECK:       %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
// CHECK:       gpu.return %[[RES]] : vector<8xf32>
}

// -----
gpu.module @xevm_module {
gpu.func @load_2D_memref(%source: memref<8x32xf32>,
    %off1: index, %off2: index,
    %indices: vector<8xindex>, %mask: vector<8xi1>,
    %pass_thru: vector<8xf32>) -> vector<8xf32> {
  %0 = vector.gather %source[%off1, %off2][%indices], %mask,
    %pass_thru : memref<8x32xf32>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
  gpu.return %0 : vector<8xf32>
}
// CHECK-LABEL: @load_2D_memref(
// CHECK-SAME:  %[[SRC:.+]]: memref<8x32xf32>,
// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// CHECK-SAME:  %[[INDICES:.+]]: vector<8xindex>
// CHECK-SAME:  %[[MASK:.+]]: vector<8xi1>
// CHECK-SAME:  %[[PASS_THRU:.+]]: vector<8xf32>) -> vector<8xf32> {
// CHECK-COUNT-1: arith.muli {{.*}} : index
// CHECK-COUNT-1: arith.addi {{.*}} : index
// CHECK:       %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8xindex>
// CHECK:       %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8xindex>
// CHECK:       %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1]{{\]}} : memref<8x32xf32> into memref<256xf32>
// CHECK:       %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<256xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
// CHECK:       %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8xi1>, vector<8xf32>
// CHECK:       gpu.return %[[RES]] : vector<8xf32>
}

// -----
gpu.module @xevm_module {
gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
    %off0: index, %off1: index, %off2: index,
    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>,
    %pass_thru: vector<8x16xf32>) -> vector<8x16xf32> {
  %0 = vector.gather %source[%off0, %off1, %off2][%indices], %mask,
    %pass_thru : memref<8x16x32xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
  gpu.return %0 : vector<8x16xf32>
}
// CHECK-LABEL: @load_2D_vector(
// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
// CHECK-SAME:  %[[INDICES:.+]]: vector<8x16xindex>
// CHECK-SAME:  %[[MASK:.+]]: vector<8x16xi1>
// CHECK-SAME:  %[[PASS_THRU:.+]]: vector<8x16xf32>) -> vector<8x16xf32> {
// CHECK-COUNT-2: arith.muli {{.*}} : index
// CHECK-COUNT-2: arith.addi {{.*}} : index
// CHECK:       %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
// CHECK:       %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
// CHECK:       %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
// CHECK:       %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
// CHECK:       %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
// CHECK:       gpu.return %[[RES]] : vector<8x16xf32>
}

// -----
gpu.module @xevm_module {
gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
    %off0: index, %off1: index, %off2: index,
    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>,
    %pass_thru: vector<8x16xf32>) -> vector<8x16xf32> {
  %0 = vector.gather %source[%off0, %off1, %off2][%indices], %mask,
    %pass_thru : memref<?x?x?xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
  gpu.return %0 : vector<8x16xf32>
}
// CHECK-LABEL: @load_dynamic_source(
// CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
// CHECK-SAME:  %[[INDICES:.+]]: vector<8x16xindex>
// CHECK-SAME:  %[[MASK:.+]]: vector<8x16xi1>
// CHECK-SAME:  %[[PASS_THRU:.+]]: vector<8x16xf32>) -> vector<8x16xf32> {
// CHECK:       memref.extract_strided_metadata %[[SRC]]
// CHECK-COUNT-2: arith.muli {{.*}} : index
// CHECK-COUNT-2: arith.addi {{.*}} : index
// CHECK:       %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
// CHECK:       %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
// CHECK:       %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
// CHECK:       %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
// CHECK:       %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
// CHECK:       gpu.return %[[RES]] : vector<8x16xf32>
}

// -----
gpu.module @xevm_module {
gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
    %off0: index, %off1: index, %off2: index,
    %indices: vector<8x16xindex>, %mask: vector<8x16xi1>,
    %pass_thru: vector<8x16xf32>) -> vector<8x16xf32> {
  %0 = vector.gather %source[%off0, %off1, %off2][%indices], %mask,
    %pass_thru : memref<?x8x16xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
  gpu.return %0 : vector<8x16xf32>
}
// CHECK-LABEL: @load_dynamic_source2(
// CHECK-SAME:  %[[SRC:.+]]: memref<?x8x16xf32>,
// CHECK-SAME:  %[[OFF1:.+]]: index, %[[OFF2:.+]]: index, %[[OFF3:.+]]: index,
// CHECK-SAME:  %[[INDICES:.+]]: vector<8x16xindex>
// CHECK-SAME:  %[[MASK:.+]]: vector<8x16xi1>
// CHECK-SAME:  %[[PASS_THRU:.+]]: vector<8x16xf32>) -> vector<8x16xf32> {
// CHECK-NOT:   memref.extract_strided_metadata %[[SRC]]
// CHECK-COUNT-2: arith.muli {{.*}} : index
// CHECK-COUNT-2: arith.addi {{.*}} : index
// CHECK:       %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
// CHECK:       %[[LIN_IDX:.+]] = arith.addi %[[SPLAT]], %[[INDICES]] : vector<8x16xindex>
// CHECK:       %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
// CHECK:       %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[LIN_IDX]]{{\]}}, %[[MASK]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
// CHECK:       %[[RES:.+]] = arith.select %[[MASK]], %[[VEC]], %[[PASS_THRU]] : vector<8x16xi1>, vector<8x16xf32>
// CHECK:       gpu.return %[[RES]] : vector<8x16xf32>
}

// -----
gpu.module @xevm_module {
gpu.func @no_load_tensor(%source: tensor<32x64xf32>,
    %off: index, %indices: vector<8x16xindex>,
    %mask: vector<8x16xi1>, %pass_thru: vector<8x16xf32>) -> vector<8x16xf32> {
  %0 = vector.gather %source[%off, %off][%indices], %mask,
    %pass_thru : tensor<32x64xf32>, vector<8x16xindex>, vector<8x16xi1>, vector<8x16xf32> into vector<8x16xf32>
  gpu.return %0 : vector<8x16xf32>
}
// CHECK-LABEL: @no_load_tensor(
// CHECK:       vector.gather
}

// -----
gpu.module @xevm_module {
gpu.func @no_load_non_unit_inner_stride(
    %source: memref<32xf32, strided<[?], offset: ?>>,
    %off: index, %indices: vector<8xindex>, %mask: vector<8xi1>,
    %pass_thru: vector<8xf32>) -> vector<8xf32> {
  %0 = vector.gather %source[%off][%indices], %mask, %pass_thru
    : memref<32xf32, strided<[?], offset: ?>>, vector<8xindex>, vector<8xi1>, vector<8xf32> into vector<8xf32>
  gpu.return %0 : vector<8xf32>
}
// CHECK-LABEL: @no_load_non_unit_inner_stride(
// CHECK:       vector.gather
}
