Commit 6c25604

[mlir][xegpu] Convert Vector load and store to XeGPU (#110826)
Adds patterns to lower vector.load|store to XeGPU operations.
1 parent 650c41a commit 6c25604
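
As a quick illustration of the new lowering, here is a minimal sketch of what a 2D vector.load turns into (shapes are borrowed from the tests in this commit; the printed form of the tensor descriptor attributes is abbreviated, so treat the output as approximate, and the SSA names are illustrative):

  // Input: a plain 2D load from a statically shaped memref.
  %0 = vector.load %source[%i, %j, %k]
      : memref<8x16x32xf32>, vector<8x16xf32>

  // Approximate result of -convert-vector-to-xegpu:
  //   %desc = xegpu.create_nd_tdesc %source[%i, %j, %k]
  //       : memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32, boundary_check = true>
  //   %0 = xegpu.load_nd %desc -> vector<8x16xf32>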

File tree

3 files changed: +291 -5 lines changed

mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp

Lines changed: 80 additions & 5 deletions
@@ -33,6 +33,7 @@ using namespace mlir;
 
 namespace {
 
+// Return true if value represents a zero constant.
 static bool isZeroConstant(Value val) {
   auto constant = val.getDefiningOp<arith::ConstantOp>();
   if (!constant)
@@ -46,6 +47,17 @@ static bool isZeroConstant(Value val) {
       .Default([](auto) { return false; });
 }
 
+static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
+                                            Operation *op, VectorType vecTy) {
+  // Validate only vector as the basic vector store and load ops guarantee
+  // XeGPU-compatible memref source.
+  unsigned vecRank = vecTy.getRank();
+  if (!(vecRank == 1 || vecRank == 2))
+    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");
+
+  return success();
+}
+
 static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                            VectorTransferOpInterface xferOp) {
   if (xferOp.getMask())
@@ -55,18 +67,21 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
   if (!srcTy)
     return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
+
+  // Perform common data transfer checks.
   VectorType vecTy = xferOp.getVectorType();
-  unsigned vecRank = vecTy.getRank();
-  if (!(vecRank == 1 || vecRank == 2))
-    return rewriter.notifyMatchFailure(xferOp, "Expects 1D or 2D vector");
+  if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
+    return failure();
 
+  // Validate further transfer op semantics.
   SmallVector<int64_t> strides;
   int64_t offset;
   if (failed(getStridesAndOffset(srcTy, strides, offset)) ||
       strides.back() != 1)
     return rewriter.notifyMatchFailure(
         xferOp, "Buffer must be contiguous in the innermost dimension");
 
+  unsigned vecRank = vecTy.getRank();
   AffineMap map = xferOp.getPermutationMap();
   if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
     return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
@@ -232,6 +247,66 @@ struct TransferWriteLowering
   }
 };
 
+struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
+  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = loadOp.getLoc();
+
+    VectorType vecTy = loadOp.getResult().getType();
+    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
+      return failure();
+
+    auto descType = xegpu::TensorDescType::get(
+        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
+        /*boundary_check=*/true, xegpu::MemorySpace::Global);
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
+
+    // By default, no specific caching policy is assigned.
+    xegpu::CachePolicyAttr hint = nullptr;
+    auto loadNdOp = rewriter.create<xegpu::LoadNdOp>(
+        loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
+        /*l1_hint=*/hint,
+        /*l2_hint=*/hint, /*l3_hint=*/hint);
+    rewriter.replaceOp(loadOp, loadNdOp);
+
+    return success();
+  }
+};
+
+struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
+  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = storeOp.getLoc();
+
+    TypedValue<VectorType> vector = storeOp.getValueToStore();
+    VectorType vecTy = vector.getType();
+    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
+      return failure();
+
+    auto descType =
+        xegpu::TensorDescType::get(vecTy.getShape(), vecTy.getElementType(),
+                                   /*array_length=*/1, /*boundary_check=*/true,
+                                   xegpu::MemorySpace::Global);
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
+
+    // By default, no specific caching policy is assigned.
+    xegpu::CachePolicyAttr hint = nullptr;
+    auto storeNdOp =
+        rewriter.create<xegpu::StoreNdOp>(loc, vector, ndDesc,
+                                          /*l1_hint=*/hint,
+                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
+    rewriter.replaceOp(storeOp, storeNdOp);
+
+    return success();
+  }
+};
+
 struct ConvertVectorToXeGPUPass
     : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
   void runOnOperation() override {
@@ -247,8 +322,8 @@ struct ConvertVectorToXeGPUPass
 
 void mlir::populateVectorToXeGPUConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<TransferReadLowering, TransferWriteLowering>(
-      patterns.getContext());
+  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
+               StoreLowering>(patterns.getContext());
 }
 
 std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
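
When the source memref has dynamic shape, the descriptor cannot be derived from the type alone, so createNdDescriptor (already present in this file) materializes the runtime sizes and strides and passes them to xegpu.create_nd_tdesc as explicit operands. A sketch reconstructed from the @load_dynamic_source FileCheck expectations below (SSA names are hypothetical):

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %d0 = memref.dim %source, %c0 : memref<?x?x?xf32>
  %d1 = memref.dim %source, %c1 : memref<?x?x?xf32>
  %d2 = memref.dim %source, %c2 : memref<?x?x?xf32>
  // Row-major strides: outermost = %d2 * %d1, middle = %d2, innermost = 1.
  %s0 = arith.muli %d2, %d1 : index
  // Shape and strides are supplied explicitly (approximate syntax):
  //   xegpu.create_nd_tdesc %source[%off, %off, %off], [%d0, %d1, %d2], [%s0, %d2, 1]
  //       : memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32, boundary_check = true>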
mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir

Lines changed: 105 additions & 0 deletions (new file)
// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s

func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector<8xf32> {
  %0 = vector.load %source[%offset, %offset, %offset]
    : memref<8x16x32xf32>, vector<8xf32>
  return %0 : vector<8xf32>
}

// CHECK-LABEL: @load_1D_vector(
// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME:  %[[OFFSET:.+]]: index
// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME:    boundary_check = true
// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
// CHECK:       return %[[VEC]]

// -----

func.func @load_2D_vector(%source: memref<8x16x32xf32>,
    %offset: index) -> vector<8x16xf32> {
  %0 = vector.load %source[%offset, %offset, %offset]
    : memref<8x16x32xf32>, vector<8x16xf32>
  return %0 : vector<8x16xf32>
}

// CHECK-LABEL: @load_2D_vector(
// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME:  %[[OFFSET:.+]]: index
// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
// CHECK-SAME:    boundary_check = true
// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
// CHECK:       return %[[VEC]]

// -----

func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
    %offset: index) -> vector<8x16xf32> {
  %0 = vector.load %source[%offset, %offset, %offset]
    : memref<?x?x?xf32>, vector<8x16xf32>
  return %0 : vector<8x16xf32>
}

// CHECK-LABEL: @load_dynamic_source(
// CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
// CHECK-SAME:  %[[OFFSET:.+]]: index
// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
// CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
// CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
// CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME:    [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
// CHECK-SAME:    boundary_check = true
// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
// CHECK:       return %[[VEC]]

// -----

func.func @load_out_of_bounds(%source: memref<7x15xf32>,
    %offset: index) -> vector<8x16xf32> {
  %0 = vector.load %source[%offset, %offset]
    : memref<7x15xf32>, vector<8x16xf32>
  return %0 : vector<8x16xf32>
}

// CHECK-LABEL: @load_out_of_bounds(
// CHECK-SAME:  %[[SRC:.+]]: memref<7x15xf32>,
// CHECK-SAME:  %[[OFFSET:.+]]: index
// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME:    memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32,
// CHECK-SAME:    boundary_check = true
// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
// CHECK:       return %[[VEC]]

// -----

func.func @no_load_high_dim_vector(%source: memref<16x32x64xf32>,
    %offset: index) -> vector<8x16x32xf32> {
  %0 = vector.load %source[%offset, %offset, %offset]
    : memref<16x32x64xf32>, vector<8x16x32xf32>
  return %0 : vector<8x16x32xf32>
}

// CHECK-LABEL: @no_load_high_dim_vector(
// CHECK: vector.load

// -----

func.func @no_load_zero_dim_vector(%source: memref<64xf32>,
    %offset: index) -> vector<f32> {
  %0 = vector.load %source[%offset]
    : memref<64xf32>, vector<f32>
  return %0 : vector<f32>
}

// CHECK-LABEL: @no_load_zero_dim_vector(
// CHECK: vector.load
mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir

Lines changed: 106 additions & 0 deletions (new file)
// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s

func.func @store_1D_vector(%vec: vector<8xf32>,
    %source: memref<8x16x32xf32>, %offset: index) {
  vector.store %vec, %source[%offset, %offset, %offset]
    : memref<8x16x32xf32>, vector<8xf32>
  return
}

// CHECK-LABEL: @store_1D_vector(
// CHECK-SAME:  %[[VEC:.+]]: vector<8xf32>,
// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME:  %[[OFFSET:.+]]: index
// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME:    boundary_check = true
// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>

// -----

func.func @store_2D_vector(%vec: vector<8x16xf32>,
    %source: memref<8x16x32xf32>, %offset: index) {
  vector.store %vec, %source[%offset, %offset, %offset]
    : memref<8x16x32xf32>, vector<8x16xf32>
  return
}

// CHECK-LABEL: @store_2D_vector(
// CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME:  %[[OFFSET:.+]]: index
// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
// CHECK-SAME:    boundary_check = true
// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>

// -----

func.func @store_dynamic_source(%vec: vector<8x16xf32>,
    %source: memref<?x?x?xf32>, %offset: index) {
  vector.store %vec, %source[%offset, %offset, %offset]
    : memref<?x?x?xf32>, vector<8x16xf32>
  return
}

// CHECK-LABEL: @store_dynamic_source(
// CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
// CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
// CHECK-SAME:  %[[OFFSET:.+]]: index
// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
// CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
// CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
// CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME:    [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
// CHECK-SAME:    boundary_check = true
// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>

// -----

func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
    %source: memref<7x64xf32>, %offset: index) {
  vector.store %vec, %source[%offset, %offset]
    : memref<7x64xf32>, vector<8x16xf32>
  return
}

// CHECK-LABEL: @store_out_of_bounds(
// CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
// CHECK-SAME:  %[[SRC:.+]]: memref<7x64xf32>,
// CHECK-SAME:  %[[OFFSET:.+]]: index
// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME:    memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32,
// CHECK-SAME:    boundary_check = true
// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>

// -----

func.func @no_store_high_dim_vector(%vec: vector<8x16x32xf32>,
    %source: memref<16x32x64xf32>, %offset: index) {
  vector.store %vec, %source[%offset, %offset, %offset]
    : memref<16x32x64xf32>, vector<8x16x32xf32>
  return
}

// CHECK-LABEL: @no_store_high_dim_vector(
// CHECK: vector.store

// -----

func.func @no_store_zero_dim_vector(%vec: vector<f32>,
    %source: memref<64xf32>, %offset: index) {
  vector.store %vec, %source[%offset]
    : memref<64xf32>, vector<f32>
  return
}

// CHECK-LABEL: @no_store_zero_dim_vector(
// CHECK: vector.store
