Skip to content

Commit 6ff1234

Browse files
committed
Unroll instead of linearize vector.to_elements.
After: * llvm/llvm-project#157740 * llvm/llvm-project#157142 the linearization of vector.to_elements pattern can be changed to either the one now upstream or to the unrolling version. This commit changes the strategy from linearizing to unrolling. Signed-off-by: Erick Ochoa <[email protected]>
1 parent 2b95846 commit 6ff1234

File tree

1 file changed

+2
-21
lines changed

1 file changed

+2
-21
lines changed

compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -177,26 +177,6 @@ static LogicalResult validateDataTypes(Operation *op,
177177
return success();
178178
}
179179

180-
/// TODO(hanchung): Delete the pattern once it is upstreamed:
181-
/// https://github.com/llvm/llvm-project/pull/156992
182-
struct LowerToElementsPattern : public OpRewritePattern<vector::ToElementsOp> {
183-
using OpRewritePattern::OpRewritePattern;
184-
LogicalResult matchAndRewrite(vector::ToElementsOp op,
185-
PatternRewriter &rewriter) const override {
186-
VectorType vecType = op.getSource().getType();
187-
if (vecType.getRank() == 1 || vecType.getNumScalableDims() > 0) {
188-
return failure();
189-
}
190-
auto vec1DType =
191-
VectorType::get({vecType.getNumElements()}, vecType.getElementType());
192-
Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
193-
vec1DType, op.getSource());
194-
rewriter.replaceOpWithNewOp<vector::ToElementsOp>(op, op.getResultTypes(),
195-
shapeCast);
196-
return success();
197-
}
198-
};
199-
200180
/// A pass that replaces all occurrences of GPU device operations with their
201181
/// corresponding ROCDL equivalent.
202182
///
@@ -281,7 +261,9 @@ struct ConvertToROCDLPass final
281261
vector::populateVectorInterleaveToShufflePatterns(patterns);
282262
vector::populateVectorContractLoweringPatterns(
283263
patterns, options.vectorContractLowering);
264+
284265
vector::populateVectorFromElementsLoweringPatterns(patterns);
266+
vector::populateVectorToElementsLoweringPatterns(patterns);
285267
vector::populateVectorGatherLoweringPatterns(patterns);
286268
vector::populateVectorMaskOpLoweringPatterns(patterns);
287269
// We currently always use 64 bit indices, thus ensure the bit width of
@@ -295,7 +277,6 @@ struct ConvertToROCDLPass final
295277
patterns, options.vectorTransposeLowering);
296278
vector::populateVectorTransferLoweringPatterns(patterns);
297279
arith::populateExpandBFloat16Patterns(patterns);
298-
patterns.insert<LowerToElementsPattern>(&getContext());
299280
if (failed(applyPatternsGreedily(m, std::move(patterns)))) {
300281
return signalPassFailure();
301282
}

0 commit comments

Comments
 (0)