@@ -445,12 +445,16 @@ static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
445445 Value mask = vector::ConstantMaskOp::create (
446446 rewriter, loc, VectorType::get (vectorShape, rewriter.getI1Type ()),
447447 vectorShape);
448- auto gatherOp = xegpu::LoadGatherOp::create (
449- rewriter, loc, vectorType, flatMemref, localOffsets, mask,
450- /* chunk_size=*/ IntegerAttr{},
451- /* l1_hint=*/ xegpu::CachePolicyAttr{},
452- /* l2_hint=*/ xegpu::CachePolicyAttr{},
453- /* l3_hint=*/ xegpu::CachePolicyAttr{});
448+ SmallVector<xegpu::CachePolicyAttr, 3 > cacheHints = getOpCacheHints (readOp);
449+ auto gatherOp = xegpu::LoadGatherOp::create (rewriter, loc, vectorType,
450+ flatMemref, localOffsets, mask,
451+ /* chunk_size=*/ IntegerAttr{},
452+ /* l1_hint=*/ cacheHints[0 ],
453+ /* l2_hint=*/ cacheHints[1 ],
454+ /* l3_hint=*/ cacheHints[2 ]);
455+ auto resLayout = xegpu::getDistributeLayoutAttr (readOp.getResult ());
456+ xegpu::setDistributeLayoutAttrs (gatherOp,
457+ [&](Value val) { return resLayout; });
454458
455459 rewriter.replaceOp (readOp, gatherOp.getResult ());
456460 return success ();
@@ -479,12 +483,16 @@ static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
479483 Value mask = vector::ConstantMaskOp::create (
480484 rewriter, loc, VectorType::get (vectorShape, rewriter.getI1Type ()),
481485 vectorShape);
482- xegpu::StoreScatterOp::create (rewriter, loc, writeOp.getVector (), flatMemref,
483- localOffsets, mask,
484- /* chunk_size=*/ IntegerAttr{},
485- /* l1_hint=*/ xegpu::CachePolicyAttr{},
486- /* l2_hint=*/ xegpu::CachePolicyAttr{},
487- /* l3_hint=*/ xegpu::CachePolicyAttr{});
486+ auto cacheHints = getOpCacheHints (writeOp);
487+ auto storeOp = xegpu::StoreScatterOp::create (
488+ rewriter, loc, writeOp.getVector (), flatMemref, localOffsets, mask,
489+ /* chunk_size=*/ IntegerAttr{},
490+ /* l1_hint=*/ cacheHints[0 ],
491+ /* l2_hint=*/ cacheHints[1 ],
492+ /* l3_hint=*/ cacheHints[2 ]);
493+ auto valueLayout = xegpu::getDistributeLayoutAttr (writeOp->getOpOperand (0 ));
494+ xegpu::setDistributeLayoutAttrs (storeOp,
495+ [&](Value val) { return valueLayout; });
488496 rewriter.eraseOp (writeOp);
489497 return success ();
490498}
@@ -534,9 +542,11 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
534542 SmallVector<int64_t > descShape (vecTy.getShape ());
535543 if (isTransposeLoad)
536544 std::reverse (descShape.begin (), descShape.end ());
537- auto descType = xegpu::TensorDescType::get (
538- descShape, elementType, /* array_length=*/ 1 ,
539- /* boundary_check=*/ isOutOfBounds, xegpu::MemorySpace::Global);
545+ auto resLayout = xegpu::getDistributeLayoutAttr (readOp.getResult ());
546+ auto descType =
547+ xegpu::TensorDescType::get (descShape, elementType, /* array_length=*/ 1 ,
548+ /* boundary_check=*/ isOutOfBounds,
549+ xegpu::MemorySpace::Global, resLayout);
540550
541551 xegpu::CreateNdDescOp ndDesc =
542552 createNdDescriptor (rewriter, loc, descType,
@@ -547,12 +557,12 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
547557 !isTransposeLoad ? nullptr
548558 : DenseI64ArrayAttr::get (rewriter.getContext (),
549559 ArrayRef<int64_t >{1 , 0 });
550- // By default, no specific caching policy is assigned.
551- xegpu::CachePolicyAttr hint = nullptr ;
560+ auto cacheHints = getOpCacheHints (readOp);
552561 auto loadOp = xegpu::LoadNdOp::create (rewriter, loc, vecTy, ndDesc,
553562 /* packed=*/ nullptr , transposeAttr,
554- /* l1_hint=*/ hint,
555- /* l2_hint=*/ hint, /* l3_hint=*/ hint);
563+ /* l1_hint=*/ cacheHints[0 ],
564+ /* l2_hint=*/ cacheHints[1 ],
565+ /* l3_hint=*/ cacheHints[2 ]);
556566 rewriter.replaceOp (readOp, loadOp);
557567
558568 return success ();
@@ -590,21 +600,24 @@ struct TransferWriteLowering
590600 if (!map.isMinorIdentity ())
591601 return rewriter.notifyMatchFailure (writeOp, " Expects identity map" );
592602
603+ auto valLayout = xegpu::getDistributeLayoutAttr (writeOp->getOpOperand (0 ));
593604 auto descType = xegpu::TensorDescType::get (
594605 vecTy.getShape (), vecTy.getElementType (),
595606 /* array_length=*/ 1 , /* boundary_check=*/ writeOp.hasOutOfBoundsDim (),
596- xegpu::MemorySpace::Global);
607+ xegpu::MemorySpace::Global, valLayout );
597608 xegpu::CreateNdDescOp ndDesc =
598609 createNdDescriptor (rewriter, loc, descType,
599610 dyn_cast<TypedValue<MemRefType>>(writeOp.getBase ()),
600611 writeOp.getIndices ());
601612
602613 // By default, no specific caching policy is assigned.
603614 xegpu::CachePolicyAttr hint = nullptr ;
615+ auto cacheHints = getOpCacheHints (writeOp);
604616 auto storeOp =
605617 xegpu::StoreNdOp::create (rewriter, loc, writeOp.getVector (), ndDesc,
606- /* l1_hint=*/ hint,
607- /* l2_hint=*/ hint, /* l3_hint=*/ hint);
618+ /* l1_hint=*/ cacheHints[0 ],
619+ /* l2_hint=*/ cacheHints[1 ],
620+ /* l3_hint=*/ cacheHints[2 ]);
608621 rewriter.replaceOp (writeOp, storeOp);
609622
610623 return success ();
@@ -720,18 +733,20 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
720733 // Boundary check is available only for block instructions.
721734 bool boundaryCheck = vecTy.getRank () > 1 ;
722735
736+ auto resLayout = xegpu::getDistributeLayoutAttr (loadOp.getResult ());
723737 auto descType = xegpu::TensorDescType::get (
724738 vecTy.getShape (), vecTy.getElementType (), /* array_length=*/ 1 ,
725- boundaryCheck, xegpu::MemorySpace::Global);
739+ boundaryCheck, xegpu::MemorySpace::Global, resLayout );
726740 xegpu::CreateNdDescOp ndDesc = createNdDescriptor (
727741 rewriter, loc, descType, loadOp.getBase (), loadOp.getIndices ());
728742
729743 // By default, no specific caching policy is assigned.
730744 xegpu::CachePolicyAttr hint = nullptr ;
745+ auto cacheHints = getOpCacheHints (loadOp);
731746 auto loadNdOp = xegpu::LoadNdOp::create (
732747 rewriter, loc, vecTy, ndDesc, /* packed=*/ nullptr , /* transpose=*/ nullptr ,
733- /* l1_hint=*/ hint ,
734- /* l2_hint=*/ hint , /* l3_hint=*/ hint );
748+ /* l1_hint=*/ cacheHints[ 0 ] ,
749+ /* l2_hint=*/ cacheHints[ 1 ] , /* l3_hint=*/ cacheHints[ 2 ] );
735750 rewriter.replaceOp (loadOp, loadNdOp);
736751
737752 return success ();
@@ -753,18 +768,21 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
753768 // Boundary check is available only for block instructions.
754769 bool boundaryCheck = vecTy.getRank () > 1 ;
755770
756- auto descType = xegpu::TensorDescType::get (
757- vecTy.getShape (), vecTy.getElementType (),
758- /* array_length=*/ 1 , boundaryCheck, xegpu::MemorySpace::Global);
771+ auto valLayout = xegpu::getDistributeLayoutAttr (storeOp->getOpOperand (0 ));
772+ auto descType =
773+ xegpu::TensorDescType::get (vecTy.getShape (), vecTy.getElementType (),
774+ /* array_length=*/ 1 , boundaryCheck,
775+ xegpu::MemorySpace::Global, valLayout);
759776 xegpu::CreateNdDescOp ndDesc = createNdDescriptor (
760777 rewriter, loc, descType, storeOp.getBase (), storeOp.getIndices ());
761778
762779 // By default, no specific caching policy is assigned.
763780 xegpu::CachePolicyAttr hint = nullptr ;
764- auto storeNdOp =
765- xegpu::StoreNdOp::create (rewriter, loc, vector, ndDesc,
766- /* l1_hint=*/ hint,
767- /* l2_hint=*/ hint, /* l3_hint=*/ hint);
781+ auto cacheHints = getOpCacheHints (storeOp);
782+ auto storeNdOp = xegpu::StoreNdOp::create (rewriter, loc, vector, ndDesc,
783+ /* l1_hint=*/ cacheHints[0 ],
784+ /* l2_hint=*/ cacheHints[1 ],
785+ /* l3_hint=*/ cacheHints[2 ]);
768786 rewriter.replaceOp (storeOp, storeNdOp);
769787
770788 return success ();
0 commit comments