Skip to content

Commit 74ea559

Browse files
committed
pull special cases in logic into separate (private static class) methods
1 parent 125573b commit 74ea559

File tree

1 file changed

+72
-33
lines changed

1 file changed

+72
-33
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp

Lines changed: 72 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,73 @@ namespace {
120120
/// algorithm described above.
121121
///
122122
class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
123+
124+
// Case (i) of description.
125+
// Assumes source and result shapes are identical up to some leading ones.
126+
static LogicalResult leadingOnesLowering(vector::ShapeCastOp shapeCast,
127+
PatternRewriter &rewriter) {
128+
129+
const Location loc = shapeCast.getLoc();
130+
const VectorType sourceType = shapeCast.getSourceVectorType();
131+
const VectorType resultType = shapeCast.getResultVectorType();
132+
133+
const int64_t sourceRank = sourceType.getRank();
134+
const int64_t resultRank = resultType.getRank();
135+
const int64_t delta = sourceRank - resultRank;
136+
const int64_t sourceLeading = delta > 0 ? delta : 0;
137+
const int64_t resultLeading = delta > 0 ? 0 : -delta;
138+
139+
const Value source = shapeCast.getSource();
140+
const Value poison = rewriter.create<ub::PoisonOp>(loc, resultType);
141+
const Value extracted = rewriter.create<vector::ExtractOp>(
142+
loc, source, SmallVector<int64_t>(sourceLeading, 0));
143+
const Value result = rewriter.create<vector::InsertOp>(
144+
loc, extracted, poison, SmallVector<int64_t>(resultLeading, 0));
145+
146+
rewriter.replaceOp(shapeCast, result);
147+
return success();
148+
}
149+
150+
// Case (ii) of description.
151+
// Assumes a shape_cast where the suffix shape of the source starting at
152+
// `sourceDim` and the suffix shape of the result starting at `resultDim` are
153+
// identical.
154+
static LogicalResult noStridedSliceLowering(vector::ShapeCastOp shapeCast,
155+
int64_t sourceDim,
156+
int64_t resultDim,
157+
PatternRewriter &rewriter) {
158+
159+
const Location loc = shapeCast.getLoc();
160+
161+
const Value source = shapeCast.getSource();
162+
const ArrayRef<int64_t> sourceShape =
163+
shapeCast.getSourceVectorType().getShape();
164+
165+
const VectorType resultType = shapeCast.getResultVectorType();
166+
const ArrayRef<int64_t> resultShape = resultType.getShape();
167+
168+
const int64_t nSlices =
169+
std::accumulate(sourceShape.begin(), sourceShape.begin() + sourceDim, 1,
170+
std::multiplies<int64_t>());
171+
172+
SmallVector<int64_t> extractIndex(sourceDim, 0);
173+
SmallVector<int64_t> insertIndex(resultDim, 0);
174+
Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
175+
176+
for (int i = 0; i < nSlices; ++i) {
177+
Value extracted =
178+
rewriter.create<vector::ExtractOp>(loc, source, extractIndex);
179+
180+
result = rewriter.create<vector::InsertOp>(loc, extracted, result,
181+
insertIndex);
182+
183+
inplaceAdd(1, sourceShape.take_front(sourceDim), extractIndex);
184+
inplaceAdd(1, resultShape.take_front(resultDim), insertIndex);
185+
}
186+
rewriter.replaceOp(shapeCast, result);
187+
return success();
188+
}
189+
123190
public:
124191
using OpRewritePattern::OpRewritePattern;
125192

@@ -163,18 +230,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
163230
// This is the case (i) where there are just some leading ones to contend
164231
// with in the source or result. It can be handled with a single
165232
// extract/insert pair.
166-
if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0) {
167-
const int64_t delta = sourceRank - resultRank;
168-
const int64_t sourceLeading = delta > 0 ? delta : 0;
169-
const int64_t resultLeading = delta > 0 ? 0 : -delta;
170-
const Value poison = rewriter.create<ub::PoisonOp>(loc, resultType);
171-
const Value extracted = rewriter.create<vector::ExtractOp>(
172-
loc, source, SmallVector<int64_t>(sourceLeading, 0));
173-
const Value result = rewriter.create<vector::InsertOp>(
174-
loc, extracted, poison, SmallVector<int64_t>(resultLeading, 0));
175-
rewriter.replaceOp(op, result);
176-
return success();
177-
}
233+
if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0)
234+
return leadingOnesLowering(op, rewriter);
178235

179236
const int64_t sourceSuffixStartDimSize =
180237
sourceType.getDimSize(sourceSuffixStartDim);
@@ -200,27 +257,9 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
200257
// IR is generated in this case if we just extract and insert the elements
201258
// directly. In other words, we don't use extract_strided_slice and
202259
// insert_strided_slice.
203-
if (greatestCommonDivisor == 1) {
204-
sourceSuffixStartDim += 1;
205-
resultSuffixStartDim += 1;
206-
SmallVector<int64_t> extractIndex(sourceSuffixStartDim, 0);
207-
SmallVector<int64_t> insertIndex(resultSuffixStartDim, 0);
208-
Value result = rewriter.create<ub::PoisonOp>(loc, resultType);
209-
for (size_t i = 0; i < nAtomicSlices; ++i) {
210-
Value extracted =
211-
rewriter.create<vector::ExtractOp>(loc, source, extractIndex);
212-
213-
result = rewriter.create<vector::InsertOp>(loc, extracted, result,
214-
insertIndex);
215-
216-
inplaceAdd(1, sourceShape.take_front(sourceSuffixStartDim),
217-
extractIndex);
218-
inplaceAdd(1, resultShape.take_front(resultSuffixStartDim),
219-
insertIndex);
220-
}
221-
rewriter.replaceOp(op, result);
222-
return success();
223-
}
260+
if (greatestCommonDivisor == 1)
261+
return noStridedSliceLowering(op, sourceSuffixStartDim + 1,
262+
resultSuffixStartDim + 1, rewriter);
224263

225264
// The insert_strided_slice result's type
226265
const ArrayRef<int64_t> insertStridedShape =

0 commit comments

Comments
 (0)