Skip to content

Commit deab2b8

Browse files
committed
make code cleaner
Signed-off-by: dchigarev <[email protected]>
1 parent 6a61a7e commit deab2b8

File tree

1 file changed

+11
-14
lines changed

1 file changed

+11
-14
lines changed

lib/gc/Transforms/GPU/AllocsToSLM.cpp

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ struct ConvertAlloc : public OpRewritePattern<memref::AllocOp> {
5555

5656
LogicalResult matchAndRewrite(memref::AllocOp allocOp,
5757
PatternRewriter &rewriter) const override {
58-
if (hasAssignedMemSpace(allocOp->getResult(0))) {
58+
Value memref = allocOp->getResult(0);
59+
60+
if (hasAssignedMemSpace(memref)) {
5961
return rewriter.notifyMatchFailure(
6062
allocOp, "Memref already has some memory space attribute");
6163
}
@@ -86,15 +88,12 @@ struct ConvertAlloc : public OpRewritePattern<memref::AllocOp> {
8688
return rewriter.notifyMatchFailure(
8789
allocOp, "Only support 2D shared memory for now");
8890

89-
int64_t totalWorkGroupSize = xI * yI * zI;
90-
91-
Value memref = allocOp->getResult(0);
92-
9391
MemRefType originalMemRefType = cast<MemRefType>(memref.getType());
92+
auto originalShape = originalMemRefType.getShape();
9493

9594
// Scale the allocation size by the number of threads in the work-group
96-
int64_t newX = originalMemRefType.getShape()[0] * xI;
97-
int64_t newY = originalMemRefType.getShape()[1] * yI;
95+
int64_t newX = originalShape[0] * xI;
96+
int64_t newY = originalShape[1] * yI;
9897

9998
SmallVector<int64_t> newShape = {newX, newY};
10099

@@ -113,12 +112,10 @@ struct ConvertAlloc : public OpRewritePattern<memref::AllocOp> {
113112
.getResult();
114113

115114
// Compute the offsets in SLM chunk for the current thread
116-
auto oneConst =
117-
rewriter.create<arith::ConstantIndexOp>(allocOp.getLoc(), 1);
118-
auto origXConst = rewriter.create<arith::ConstantIndexOp>(
119-
allocOp.getLoc(), originalMemRefType.getShape()[0]);
120-
auto origYConst = rewriter.create<arith::ConstantIndexOp>(
121-
allocOp.getLoc(), originalMemRefType.getShape()[1]);
115+
auto origXConst = rewriter.create<arith::ConstantIndexOp>(allocOp.getLoc(),
116+
originalShape[0]);
117+
auto origYConst = rewriter.create<arith::ConstantIndexOp>(allocOp.getLoc(),
118+
originalShape[1]);
122119

123120
auto threadIds = launchOp.getThreadIds();
124121

@@ -133,7 +130,7 @@ struct ConvertAlloc : public OpRewritePattern<memref::AllocOp> {
133130

134131
auto offsets = getMixedValues({ShapedType::kDynamic, ShapedType::kDynamic},
135132
{offX, offY}, rewriter);
136-
auto sizes = getMixedValues(originalMemRefType.getShape(), {}, rewriter);
133+
auto sizes = getMixedValues(originalShape, {}, rewriter);
137134
auto strides = getMixedValues({1, 1}, {}, rewriter);
138135

139136
auto newSlice =

0 commit comments

Comments
 (0)