@@ -120,6 +120,73 @@ namespace {
120120// / algorithm described above.
121121// /
122122class ShapeCastOpRewritePattern : public OpRewritePattern <vector::ShapeCastOp> {
123+
124+ // Case (i) of description.
125+ // Assumes source and result shapes are identical up to some leading ones.
126+ static LogicalResult leadingOnesLowering (vector::ShapeCastOp shapeCast,
127+ PatternRewriter &rewriter) {
128+
129+ const Location loc = shapeCast.getLoc ();
130+ const VectorType sourceType = shapeCast.getSourceVectorType ();
131+ const VectorType resultType = shapeCast.getResultVectorType ();
132+
133+ const int64_t sourceRank = sourceType.getRank ();
134+ const int64_t resultRank = resultType.getRank ();
135+ const int64_t delta = sourceRank - resultRank;
136+ const int64_t sourceLeading = delta > 0 ? delta : 0 ;
137+ const int64_t resultLeading = delta > 0 ? 0 : -delta;
138+
139+ const Value source = shapeCast.getSource ();
140+ const Value poison = rewriter.create <ub::PoisonOp>(loc, resultType);
141+ const Value extracted = rewriter.create <vector::ExtractOp>(
142+ loc, source, SmallVector<int64_t >(sourceLeading, 0 ));
143+ const Value result = rewriter.create <vector::InsertOp>(
144+ loc, extracted, poison, SmallVector<int64_t >(resultLeading, 0 ));
145+
146+ rewriter.replaceOp (shapeCast, result);
147+ return success ();
148+ }
149+
150+ // Case (ii) of description.
151+ // Assumes a shape_cast where the suffix shape of the source starting at
152+ // `sourceDim` and the suffix shape of the result starting at `resultDim` are
153+ // identical.
154+ static LogicalResult noStridedSliceLowering (vector::ShapeCastOp shapeCast,
155+ int64_t sourceDim,
156+ int64_t resultDim,
157+ PatternRewriter &rewriter) {
158+
159+ const Location loc = shapeCast.getLoc ();
160+
161+ const Value source = shapeCast.getSource ();
162+ const ArrayRef<int64_t > sourceShape =
163+ shapeCast.getSourceVectorType ().getShape ();
164+
165+ const VectorType resultType = shapeCast.getResultVectorType ();
166+ const ArrayRef<int64_t > resultShape = resultType.getShape ();
167+
168+ const int64_t nSlices =
169+ std::accumulate (sourceShape.begin (), sourceShape.begin () + sourceDim, 1 ,
170+ std::multiplies<int64_t >());
171+
172+ SmallVector<int64_t > extractIndex (sourceDim, 0 );
173+ SmallVector<int64_t > insertIndex (resultDim, 0 );
174+ Value result = rewriter.create <ub::PoisonOp>(loc, resultType);
175+
176+ for (int i = 0 ; i < nSlices; ++i) {
177+ Value extracted =
178+ rewriter.create <vector::ExtractOp>(loc, source, extractIndex);
179+
180+ result = rewriter.create <vector::InsertOp>(loc, extracted, result,
181+ insertIndex);
182+
183+ inplaceAdd (1 , sourceShape.take_front (sourceDim), extractIndex);
184+ inplaceAdd (1 , resultShape.take_front (resultDim), insertIndex);
185+ }
186+ rewriter.replaceOp (shapeCast, result);
187+ return success ();
188+ }
189+
123190public:
124191 using OpRewritePattern::OpRewritePattern;
125192
@@ -163,18 +230,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
163230 // This is the case (i) where there are just some leading ones to contend
164231 // with in the source or result. It can be handled with a single
165232 // extract/insert pair.
166- if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0 ) {
167- const int64_t delta = sourceRank - resultRank;
168- const int64_t sourceLeading = delta > 0 ? delta : 0 ;
169- const int64_t resultLeading = delta > 0 ? 0 : -delta;
170- const Value poison = rewriter.create <ub::PoisonOp>(loc, resultType);
171- const Value extracted = rewriter.create <vector::ExtractOp>(
172- loc, source, SmallVector<int64_t >(sourceLeading, 0 ));
173- const Value result = rewriter.create <vector::InsertOp>(
174- loc, extracted, poison, SmallVector<int64_t >(resultLeading, 0 ));
175- rewriter.replaceOp (op, result);
176- return success ();
177- }
233+ if (resultSuffixStartDim < 0 || sourceSuffixStartDim < 0 )
234+ return leadingOnesLowering (op, rewriter);
178235
179236 const int64_t sourceSuffixStartDimSize =
180237 sourceType.getDimSize (sourceSuffixStartDim);
@@ -200,27 +257,9 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
200257 // IR is generated in this case if we just extract and insert the elements
201258 // directly. In other words, we don't use extract_strided_slice and
202259 // insert_strided_slice.
203- if (greatestCommonDivisor == 1 ) {
204- sourceSuffixStartDim += 1 ;
205- resultSuffixStartDim += 1 ;
206- SmallVector<int64_t > extractIndex (sourceSuffixStartDim, 0 );
207- SmallVector<int64_t > insertIndex (resultSuffixStartDim, 0 );
208- Value result = rewriter.create <ub::PoisonOp>(loc, resultType);
209- for (size_t i = 0 ; i < nAtomicSlices; ++i) {
210- Value extracted =
211- rewriter.create <vector::ExtractOp>(loc, source, extractIndex);
212-
213- result = rewriter.create <vector::InsertOp>(loc, extracted, result,
214- insertIndex);
215-
216- inplaceAdd (1 , sourceShape.take_front (sourceSuffixStartDim),
217- extractIndex);
218- inplaceAdd (1 , resultShape.take_front (resultSuffixStartDim),
219- insertIndex);
220- }
221- rewriter.replaceOp (op, result);
222- return success ();
223- }
260+ if (greatestCommonDivisor == 1 )
261+ return noStridedSliceLowering (op, sourceSuffixStartDim + 1 ,
262+ resultSuffixStartDim + 1 , rewriter);
224263
225264 // The insert_strided_slice result's type
226265 const ArrayRef<int64_t > insertStridedShape =
0 commit comments