Skip to content

Commit b859a9f

Browse files
authored
[mlir][ArmNeon] Update LowerContractionToSMMLAPattern to support proper unrolling for k dimension (llvm#88591)
Fixes correctness issue with current smmla unrolling patterns whereby unrolling K dimension would only include the result from the last tile along K. Updates patterns to feed previous smmla output of the previous tile into the next one along K.
1 parent f4c0c40 commit b859a9f

File tree

2 files changed

+118
-23
lines changed

2 files changed

+118
-23
lines changed

mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,9 @@ class LowerContractionToSMMLAPattern
133133
smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
134134
loopOrder.push_back(2);
135135
}
136+
137+
// Keep track of the previous accumulator when tiling over K.
138+
Value kAcc;
136139
for (SmallVector<int64_t> offsets :
137140
StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
138141
// Helper to compute the new shape of each operand and extract the slice.
@@ -194,19 +197,26 @@ class LowerContractionToSMMLAPattern
194197
tiledRhs.getLoc(), collapsedInputType, tiledRhs);
195198
auto collapsedOutputType =
196199
VectorType::get(outputExpandedType.getNumElements(), accElementType);
197-
auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
198-
tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
200+
201+
bool initialKAcc = offsets.back() == 0;
202+
Value collapsedRes;
203+
if (!initialKAcc) {
204+
collapsedRes = kAcc;
205+
} else {
206+
collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
207+
tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
208+
}
199209

200210
// Insert contract op
201-
auto smmlaOp = rewriter.createOrFold<arm_neon::SmmlaOp>(
211+
kAcc = rewriter.createOrFold<arm_neon::SmmlaOp>(
202212
op.getLoc(), collapsedRes.getType(), collapsedRes, collapsedLhs,
203213
collapsedRhs);
204214

205215
// Reshape output back to 2D
206216
Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
207-
smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
217+
kAcc.getLoc(), tiledAcc.getType(), kAcc);
208218

209-
// With vecmat, only one row of tiled ACC can be inserted inot file result
219+
// With vecmat, only one row of tiled ACC can be inserted into file result
210220
if (isVecmat) {
211221
tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
212222
}

0 commit comments

Comments
 (0)