@@ -33,6 +33,7 @@ using namespace mlir;
3333
3434namespace {
3535
36+ // Return true if value represents a zero constant.
3637static bool isZeroConstant (Value val) {
3738 auto constant = val.getDefiningOp <arith::ConstantOp>();
3839 if (!constant)
@@ -46,6 +47,17 @@ static bool isZeroConstant(Value val) {
4647 .Default ([](auto ) { return false ; });
4748}
4849
50+ static LogicalResult storeLoadPreconditions (PatternRewriter &rewriter,
51+ Operation *op, VectorType vecTy) {
52+ // Validate only vector as the basic vector store and load ops guarantee
53+ // XeGPU-compatible memref source.
54+ unsigned vecRank = vecTy.getRank ();
55+ if (!(vecRank == 1 || vecRank == 2 ))
56+ return rewriter.notifyMatchFailure (op, " Expects 1D or 2D vector" );
57+
58+ return success ();
59+ }
60+
4961static LogicalResult transferPreconditions (PatternRewriter &rewriter,
5062 VectorTransferOpInterface xferOp) {
5163 if (xferOp.getMask ())
@@ -55,18 +67,21 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
5567 auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType ());
5668 if (!srcTy)
5769 return rewriter.notifyMatchFailure (xferOp, " Expects memref source" );
70+
71+ // Perform common data transfer checks.
5872 VectorType vecTy = xferOp.getVectorType ();
59- unsigned vecRank = vecTy.getRank ();
60- if (!(vecRank == 1 || vecRank == 2 ))
61- return rewriter.notifyMatchFailure (xferOp, " Expects 1D or 2D vector" );
73+ if (failed (storeLoadPreconditions (rewriter, xferOp, vecTy)))
74+ return failure ();
6275
76+ // Validate further transfer op semantics.
6377 SmallVector<int64_t > strides;
6478 int64_t offset;
6579 if (failed (getStridesAndOffset (srcTy, strides, offset)) ||
6680 strides.back () != 1 )
6781 return rewriter.notifyMatchFailure (
6882 xferOp, " Buffer must be contiguous in the innermost dimension" );
6983
84+ unsigned vecRank = vecTy.getRank ();
7085 AffineMap map = xferOp.getPermutationMap ();
7186 if (!map.isProjectedPermutation (/* allowZeroInResults=*/ false ))
7287 return rewriter.notifyMatchFailure (xferOp, " Unsupported permutation map" );
@@ -232,6 +247,66 @@ struct TransferWriteLowering
232247 }
233248};
234249
250+ struct LoadLowering : public OpRewritePattern <vector::LoadOp> {
251+ using OpRewritePattern<vector::LoadOp>::OpRewritePattern;
252+
253+ LogicalResult matchAndRewrite (vector::LoadOp loadOp,
254+ PatternRewriter &rewriter) const override {
255+ Location loc = loadOp.getLoc ();
256+
257+ VectorType vecTy = loadOp.getResult ().getType ();
258+ if (failed (storeLoadPreconditions (rewriter, loadOp, vecTy)))
259+ return failure ();
260+
261+ auto descType = xegpu::TensorDescType::get (
262+ vecTy.getShape (), vecTy.getElementType (), /* array_length=*/ 1 ,
263+ /* boundary_check=*/ true , xegpu::MemorySpace::Global);
264+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor (
265+ rewriter, loc, descType, loadOp.getBase (), loadOp.getIndices ());
266+
267+ // By default, no specific caching policy is assigned.
268+ xegpu::CachePolicyAttr hint = nullptr ;
269+ auto loadNdOp = rewriter.create <xegpu::LoadNdOp>(
270+ loc, vecTy, ndDesc, /* packed=*/ nullptr , /* transpose=*/ nullptr ,
271+ /* l1_hint=*/ hint,
272+ /* l2_hint=*/ hint, /* l3_hint=*/ hint);
273+ rewriter.replaceOp (loadOp, loadNdOp);
274+
275+ return success ();
276+ }
277+ };
278+
279+ struct StoreLowering : public OpRewritePattern <vector::StoreOp> {
280+ using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
281+
282+ LogicalResult matchAndRewrite (vector::StoreOp storeOp,
283+ PatternRewriter &rewriter) const override {
284+ Location loc = storeOp.getLoc ();
285+
286+ TypedValue<VectorType> vector = storeOp.getValueToStore ();
287+ VectorType vecTy = vector.getType ();
288+ if (failed (storeLoadPreconditions (rewriter, storeOp, vecTy)))
289+ return failure ();
290+
291+ auto descType =
292+ xegpu::TensorDescType::get (vecTy.getShape (), vecTy.getElementType (),
293+ /* array_length=*/ 1 , /* boundary_check=*/ true ,
294+ xegpu::MemorySpace::Global);
295+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor (
296+ rewriter, loc, descType, storeOp.getBase (), storeOp.getIndices ());
297+
298+ // By default, no specific caching policy is assigned.
299+ xegpu::CachePolicyAttr hint = nullptr ;
300+ auto storeNdOp =
301+ rewriter.create <xegpu::StoreNdOp>(loc, vector, ndDesc,
302+ /* l1_hint=*/ hint,
303+ /* l2_hint=*/ hint, /* l3_hint=*/ hint);
304+ rewriter.replaceOp (storeOp, storeNdOp);
305+
306+ return success ();
307+ }
308+ };
309+
235310struct ConvertVectorToXeGPUPass
236311 : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
237312 void runOnOperation () override {
@@ -247,8 +322,8 @@ struct ConvertVectorToXeGPUPass
247322
248323void mlir::populateVectorToXeGPUConversionPatterns (
249324 RewritePatternSet &patterns) {
250- patterns.add <TransferReadLowering, TransferWriteLowering>(
251- patterns.getContext ());
325+ patterns.add <TransferReadLowering, TransferWriteLowering, LoadLowering,
326+ StoreLowering>( patterns.getContext ());
252327}
253328
254329std::unique_ptr<Pass> mlir::createConvertVectorToXeGPUPass () {
0 commit comments