@@ -686,12 +686,12 @@ struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
686686 using UnrollPattern<xegpu::LoadMatrixOp>::UnrollPattern;
687687 LogicalResult matchAndRewrite (xegpu::LoadMatrixOp op,
688688 PatternRewriter &rewriter) const override {
689+ Location loc = op.getLoc ();
690+ VectorType valueTy = op.getType ();
689691 std::optional<SmallVector<int64_t >> targetShape = getTargetShape (op);
690- if (!targetShape)
692+ if (!targetShape || targetShape-> size () != ( size_t )valueTy. getRank () )
691693 return failure ();
692694
693- Location loc = op.getLoc ();
694- VectorType valueTy = op.getType ();
695695 Type elemTy = valueTy.getElementType ();
696696 ArrayRef<int64_t > shape = valueTy.getShape ();
697697 auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr ());
@@ -702,17 +702,17 @@ struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
702702 SmallVector<SmallVector<OpFoldResult>> offsetsList;
703703 for (SmallVector<int64_t > offsets :
704704 StaticTileOffsetRange (shape, *targetShape)) {
705- auto adds = xegpu::addWithRightAligned (
705+ auto adds = xegpu::addElementwise (
706706 rewriter, loc, mixedOffsets,
707707 getAsIndexOpFoldResult (op.getContext (), offsets));
708708 offsetsList.push_back (adds);
709709 }
710710
711711 SmallVector<Value> newOps;
712+ layout = layout.dropInstData ();
712713 for (SmallVector<OpFoldResult> offsets : offsetsList) {
713714 auto newOp = rewriter.create <xegpu::LoadMatrixOp>(
714- op.getLoc (), newValueTy, op.getMemDesc (), offsets,
715- layout.dropInstData ());
715+ op.getLoc (), newValueTy, op.getMemDesc (), offsets, layout);
716716 newOps.push_back (newOp);
717717 }
718718 Value castOp = unpack (newOps, op.getType (), *targetShape, loc, rewriter);
@@ -743,7 +743,7 @@ struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
743743 SmallVector<SmallVector<OpFoldResult>> offsetsList;
744744 for (SmallVector<int64_t > offsets :
745745 StaticTileOffsetRange (shape, *targetShape)) {
746- auto adds = xegpu::addWithRightAligned (
746+ auto adds = xegpu::addElementwise (
747747 rewriter, loc, mixedOffsets,
748748 getAsIndexOpFoldResult (op.getContext (), offsets));
749749 offsetsList.push_back (adds);
0 commit comments