@@ -45,14 +45,19 @@ bool hasAssignedMemSpace(Value value) {
4545 return false ;
4646}
4747
48+ // Converts `memref::AllocOp` within GPU regions to the GPU shared local
49+ // memory. Adjusts the allocation shape based on GPU block dimensions and
50+ // creates a `memref::SubViewOp` for thread-specific memory access.
4851struct ConvertAlloc : public OpRewritePattern <memref::AllocOp> {
4952 using OpRewritePattern<memref::AllocOp>::OpRewritePattern;
5053
5154 ConvertAlloc (MLIRContext *ctx) : OpRewritePattern<memref::AllocOp>(ctx) {}
5255
5356 LogicalResult matchAndRewrite (memref::AllocOp allocOp,
5457 PatternRewriter &rewriter) const override {
55- if (hasAssignedMemSpace (allocOp->getResult (0 ))) {
58+ Value memref = allocOp->getResult (0 );
59+
60+ if (hasAssignedMemSpace (memref)) {
5661 return rewriter.notifyMatchFailure (
5762 allocOp, " Memref already has some memory space attribute" );
5863 }
@@ -62,22 +67,83 @@ struct ConvertAlloc : public OpRewritePattern<memref::AllocOp> {
6267 " Only support allocs in GPU regions" );
6368 }
6469
65- Value memref = allocOp->getResult (0 );
70+ auto launchOp = allocOp->getParentOfType <gpu::LaunchOp>();
71+
72+ auto xSz = dyn_cast<arith::ConstantIndexOp>(
73+ launchOp.getBlockSizeX ().getDefiningOp ());
74+ auto ySz = dyn_cast<arith::ConstantIndexOp>(
75+ launchOp.getBlockSizeY ().getDefiningOp ());
76+ auto zSz = dyn_cast<arith::ConstantIndexOp>(
77+ launchOp.getBlockSizeZ ().getDefiningOp ());
78+
79+ if (!xSz || !ySz || !zSz)
80+ return rewriter.notifyMatchFailure (
81+ allocOp, " Only support constant block sizes for now" );
82+
83+ int64_t xI = xSz.value ();
84+ int64_t yI = ySz.value ();
85+ int64_t zI = zSz.value ();
86+
87+ if (zI != 1 )
88+ return rewriter.notifyMatchFailure (
89+ allocOp, " Only support 2D shared memory for now" );
90+
6691 MemRefType originalMemRefType = cast<MemRefType>(memref.getType ());
92+ auto originalShape = originalMemRefType.getShape ();
93+
94+ // Scale the allocation size by the number of threads in the work-group
95+ int64_t newX = originalShape[0 ] * xI;
96+ int64_t newY = originalShape[1 ] * yI;
97+
98+ SmallVector<int64_t > newShape = {newX, newY};
6799
68100 IntegerAttr sharedAddressSpace =
69101 IntegerAttr::get (rewriter.getIntegerType (64 ),
70102 static_cast <int64_t >(gpu::AddressSpace::Private));
71103
72- // Create a new MemRefType with the desired address space
73- MemRefType newMemRefType = MemRefType::get (
74- originalMemRefType.getShape (), originalMemRefType.getElementType (),
75- originalMemRefType.getLayout (), sharedAddressSpace);
76-
77- Value newMemRef = rewriter.create <memref::AllocOp>(
78- allocOp.getLoc (), newMemRefType, allocOp.getOperands ());
79-
80- memref.replaceAllUsesWith (newMemRef);
104+ MemRefType newRootMemRefType =
105+ MemRefType::get (newShape, originalMemRefType.getElementType (),
106+ originalMemRefType.getLayout (), sharedAddressSpace);
107+
108+ Value newRootMemRef =
109+ rewriter
110+ .create <memref::AllocOp>(allocOp.getLoc (), newRootMemRefType,
111+ allocOp.getOperands ())
112+ .getResult ();
113+
114+ // Compute the offsets in SLM chunk for the current thread
115+ auto origXConst = rewriter.create <arith::ConstantIndexOp>(allocOp.getLoc (),
116+ originalShape[0 ]);
117+ auto origYConst = rewriter.create <arith::ConstantIndexOp>(allocOp.getLoc (),
118+ originalShape[1 ]);
119+
120+ auto threadIds = launchOp.getThreadIds ();
121+
122+ auto offX =
123+ rewriter
124+ .create <arith::MulIOp>(allocOp.getLoc (), threadIds.x , origXConst)
125+ .getResult ();
126+ auto offY =
127+ rewriter
128+ .create <arith::MulIOp>(allocOp.getLoc (), threadIds.y , origYConst)
129+ .getResult ();
130+
131+ auto offsets = getMixedValues ({ShapedType::kDynamic , ShapedType::kDynamic },
132+ {offX, offY}, rewriter);
133+ auto sizes = getMixedValues (originalShape, {}, rewriter);
134+ auto strides = getMixedValues ({1 , 1 }, {}, rewriter);
135+
136+ auto newSlice =
137+ rewriter
138+ .create <memref::SubViewOp>(allocOp.getLoc (), newRootMemRef, offsets,
139+ sizes, strides)
140+ .getResult ();
141+ memref.replaceAllUsesWith (newSlice);
142+
143+ // Erase deallocs since we don't need them for SLM
144+ for (auto user : newSlice.getUsers ())
145+ if (auto deallocOp = dyn_cast<memref::DeallocOp>(user))
146+ deallocOp->erase ();
81147
82148 return success ();
83149 }
0 commit comments