@@ -177,26 +177,6 @@ static LogicalResult validateDataTypes(Operation *op,
177177 return success ();
178178}
179179
180- // / TODO(hanchung): Delete the pattern once it is upstreamed:
181- // / https://github.com/llvm/llvm-project/pull/156992
182- struct LowerToElementsPattern : public OpRewritePattern <vector::ToElementsOp> {
183- using OpRewritePattern::OpRewritePattern;
184- LogicalResult matchAndRewrite (vector::ToElementsOp op,
185- PatternRewriter &rewriter) const override {
186- VectorType vecType = op.getSource ().getType ();
187- if (vecType.getRank () == 1 || vecType.getNumScalableDims () > 0 ) {
188- return failure ();
189- }
190- auto vec1DType =
191- VectorType::get ({vecType.getNumElements ()}, vecType.getElementType ());
192- Value shapeCast = vector::ShapeCastOp::create (rewriter, op.getLoc (),
193- vec1DType, op.getSource ());
194- rewriter.replaceOpWithNewOp <vector::ToElementsOp>(op, op.getResultTypes (),
195- shapeCast);
196- return success ();
197- }
198- };
199-
200180// / A pass that replaces all occurrences of GPU device operations with their
201181// / corresponding ROCDL equivalent.
202182// /
@@ -281,7 +261,9 @@ struct ConvertToROCDLPass final
281261 vector::populateVectorInterleaveToShufflePatterns (patterns);
282262 vector::populateVectorContractLoweringPatterns (
283263 patterns, options.vectorContractLowering );
264+
284265 vector::populateVectorFromElementsLoweringPatterns (patterns);
266+ vector::populateVectorToElementsLoweringPatterns (patterns);
285267 vector::populateVectorGatherLoweringPatterns (patterns);
286268 vector::populateVectorMaskOpLoweringPatterns (patterns);
287269 // We currently always use 64 bit indices, thus ensure the bit width of
@@ -295,7 +277,6 @@ struct ConvertToROCDLPass final
295277 patterns, options.vectorTransposeLowering );
296278 vector::populateVectorTransferLoweringPatterns (patterns);
297279 arith::populateExpandBFloat16Patterns (patterns);
298- patterns.insert <LowerToElementsPattern>(&getContext ());
299280 if (failed (applyPatternsGreedily (m, std::move (patterns)))) {
300281 return signalPassFailure ();
301282 }
0 commit comments