@@ -33,6 +33,7 @@ using namespace mlir;
 
 namespace {
 
+// Return true if value represents a zero constant.
 static bool isZeroConstant(Value val) {
   auto constant = val.getDefiningOp<arith::ConstantOp>();
   if (!constant)
@@ -46,6 +47,17 @@ static bool isZeroConstant(Value val) {
       .Default([](auto) { return false; });
 }
 
+static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
+                                            Operation *op, VectorType vecTy) {
+  // Validate only vector as the basic vector store and load ops guarantee
+  // XeGPU-compatible memref source.
+  unsigned vecRank = vecTy.getRank();
+  if (!(vecRank == 1 || vecRank == 2))
+    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");
+
+  return success();
+}
+
 static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                            VectorTransferOpInterface xferOp) {
   if (xferOp.getMask())
@@ -55,18 +67,21 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
   if (!srcTy)
     return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
+
+  // Perform common data transfer checks.
   VectorType vecTy = xferOp.getVectorType();
-  unsigned vecRank = vecTy.getRank();
-  if (!(vecRank == 1 || vecRank == 2))
-    return rewriter.notifyMatchFailure(xferOp, "Expects 1D or 2D vector");
+  if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
+    return failure();
 
+  // Validate further transfer op semantics.
   SmallVector<int64_t> strides;
   int64_t offset;
   if (failed(getStridesAndOffset(srcTy, strides, offset)) ||
       strides.back() != 1)
     return rewriter.notifyMatchFailure(
         xferOp, "Buffer must be contiguous in the innermost dimension");
 
+  unsigned vecRank = vecTy.getRank();
   AffineMap map = xferOp.getPermutationMap();
   if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
     return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
@@ -232,6 +247,66 @@ struct TransferWriteLowering
   }
 };
 
+struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
+  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = loadOp.getLoc();
+
+    VectorType vecTy = loadOp.getResult().getType();
+    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
+      return failure();
+
+    auto descType = xegpu::TensorDescType::get(
+        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
+        /*boundary_check=*/true, xegpu::MemorySpace::Global);
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
+
+    // By default, no specific caching policy is assigned.
+    xegpu::CachePolicyAttr hint = nullptr;
+    auto loadNdOp = rewriter.create<xegpu::LoadNdOp>(
+        loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
+        /*l1_hint=*/hint,
+        /*l2_hint=*/hint, /*l3_hint=*/hint);
+    rewriter.replaceOp(loadOp, loadNdOp);
+
+    return success();
+  }
+};
+
+struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
+  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = storeOp.getLoc();
+
+    TypedValue<VectorType> vector = storeOp.getValueToStore();
+    VectorType vecTy = vector.getType();
+    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
+      return failure();
+
+    auto descType =
+        xegpu::TensorDescType::get(vecTy.getShape(), vecTy.getElementType(),
+                                   /*array_length=*/1, /*boundary_check=*/true,
+                                   xegpu::MemorySpace::Global);
+    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
+
+    // By default, no specific caching policy is assigned.
+    xegpu::CachePolicyAttr hint = nullptr;
+    auto storeNdOp =
+        rewriter.create<xegpu::StoreNdOp>(loc, vector, ndDesc,
+                                          /*l1_hint=*/hint,
+                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
+    rewriter.replaceOp(storeOp, storeNdOp);
+
+    return success();
+  }
+};
+
 struct ConvertVectorToXeGPUPass
     : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
   void runOnOperation() override {
@@ -247,8 +322,8 @@ struct ConvertVectorToXeGPUPass
 
 void mlir::populateVectorToXeGPUConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<TransferReadLowering, TransferWriteLowering>(
-      patterns.getContext());
+  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
+               StoreLowering>(patterns.getContext());
 }
 
 std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass() {
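
As an aside (not part of this diff): a minimal sketch of how the patterns registered by populateVectorToXeGPUConversionPatterns could be driven outside the bundled ConvertVectorToXeGPUPass, assuming the standard greedy rewrite driver. The helper name runVectorToXeGPU is hypothetical.

// Hypothetical standalone driver for the Vector-to-XeGPU patterns.
#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static void runVectorToXeGPU(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::populateVectorToXeGPUConversionPatterns(patterns);
  // Vector ops that fail the preconditions (e.g. 3-D vectors) simply do not
  // match and are left untouched; the driver result only reports convergence.
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}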