@@ -55,7 +55,9 @@ struct ConvertAlloc : public OpRewritePattern<memref::AllocOp> {
5555
5656 LogicalResult matchAndRewrite (memref::AllocOp allocOp,
5757 PatternRewriter &rewriter) const override {
58- if (hasAssignedMemSpace (allocOp->getResult (0 ))) {
58+ Value memref = allocOp->getResult (0 );
59+
60+ if (hasAssignedMemSpace (memref)) {
5961 return rewriter.notifyMatchFailure (
6062 allocOp, " Memref already has some memory space attribute" );
6163 }
@@ -86,15 +88,12 @@ struct ConvertAlloc : public OpRewritePattern<memref::AllocOp> {
8688 return rewriter.notifyMatchFailure (
8789 allocOp, " Only support 2D shared memory for now" );
8890
89- int64_t totalWorkGroupSize = xI * yI * zI;
90-
91- Value memref = allocOp->getResult (0 );
92-
9391 MemRefType originalMemRefType = cast<MemRefType>(memref.getType ());
92+ auto originalShape = originalMemRefType.getShape ();
9493
9594 // Scale the allocation size by the number of threads in the work-group
96- int64_t newX = originalMemRefType. getShape () [0 ] * xI;
97- int64_t newY = originalMemRefType. getShape () [1 ] * yI;
95+ int64_t newX = originalShape [0 ] * xI;
96+ int64_t newY = originalShape [1 ] * yI;
9897
9998 SmallVector<int64_t > newShape = {newX, newY};
10099
@@ -113,12 +112,10 @@ struct ConvertAlloc : public OpRewritePattern<memref::AllocOp> {
113112 .getResult ();
114113
115114 // Compute the offsets in SLM chunk for the current thread
116- auto oneConst =
117- rewriter.create <arith::ConstantIndexOp>(allocOp.getLoc (), 1 );
118- auto origXConst = rewriter.create <arith::ConstantIndexOp>(
119- allocOp.getLoc (), originalMemRefType.getShape ()[0 ]);
120- auto origYConst = rewriter.create <arith::ConstantIndexOp>(
121- allocOp.getLoc (), originalMemRefType.getShape ()[1 ]);
115+ auto origXConst = rewriter.create <arith::ConstantIndexOp>(allocOp.getLoc (),
116+ originalShape[0 ]);
117+ auto origYConst = rewriter.create <arith::ConstantIndexOp>(allocOp.getLoc (),
118+ originalShape[1 ]);
122119
123120 auto threadIds = launchOp.getThreadIds ();
124121
@@ -133,7 +130,7 @@ struct ConvertAlloc : public OpRewritePattern<memref::AllocOp> {
133130
134131 auto offsets = getMixedValues ({ShapedType::kDynamic , ShapedType::kDynamic },
135132 {offX, offY}, rewriter);
136- auto sizes = getMixedValues (originalMemRefType. getShape () , {}, rewriter);
133+ auto sizes = getMixedValues (originalShape , {}, rewriter);
137134 auto strides = getMixedValues ({1 , 1 }, {}, rewriter);
138135
139136 auto newSlice =
0 commit comments