@@ -834,11 +834,94 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
834834 vector::UnrollVectorOptions options;
835835};
836836
837+ // / Takes a 1 dimensional `vector.to_element` op and attempts to change it to
838+ // / the target shape.
839+ // /
840+ // / ```
841+ // / // In SPIR-V's default environment vector of size 8
842+ // / // are not allowed.
843+ // / %elements:8 = vector.to_elements %v : vector<8xf32>
844+ // /
845+ // / ===>
846+ // /
847+ // / %v_0_to_3 = vector.extract %v[0] : vector<4xf32> from vector<8xf32>
848+ // / %v_4_to_7 = vector.extract %v[4] : vector<4xf32> from vector<8xf32>
849+ // / %elements_0:4 = vector.to_elements %v_0_to_3 : vector<4xf32>
850+ // / %elements_1:4 = vector.to_elements %v_4_to_7 : vector<4xf32>
851+ // / ```
852+ // /
853+ // / This pattern may fail if the rank is not divisible by to a native shape
854+ // / or if the rank is already in the target shape and therefore it may be
855+ // / skipped.
856+ struct ToElementsToTargetShape final
857+ : public OpRewritePattern<vector::ToElementsOp> {
858+ ToElementsToTargetShape (MLIRContext *context,
859+ const vector::UnrollVectorOptions &options,
860+ PatternBenefit benefit = 1 )
861+ : OpRewritePattern<vector::ToElementsOp>(context, benefit),
862+ options (options) {}
863+
864+ LogicalResult matchAndRewrite (vector::ToElementsOp op,
865+ PatternRewriter &rewriter) const override {
866+ auto targetShape = getTargetShape (options, op);
867+ if (!targetShape)
868+ return failure ();
869+
870+ // We have
871+ // source_rank = N * target_rank
872+ int64_t source_rank = op.getSourceVectorType ().getShape ().front ();
873+ int64_t target_rank = targetShape->front ();
874+ int64_t N = source_rank / target_rank;
875+
876+ // Transformation where
877+ // s = source_rank and
878+ // t = target_rank
879+ // ```
880+ // %e:s = vector.to_elements %v : vector<sxf32>
881+ //
882+ // ===>
883+ //
884+ // // N vector.extract_strided_slice of size t
885+ // %v0 = vector.extract_strided_slice %v {offsets = [0*t], sizes = [t], strides = [1]} : vector<txf32> from vector<sxf32>
886+ // %v1 = vector.extract_strided_slice %v {offsets = [1*t], sizes = [t], strides = [1]} : vector<txf32> from vector<sxf32>
887+ // ...
888+ // %vNminus1 = vector.extract_strided_slice $v {offsets = [(N-1)*t], sizes = [t], strides = [1]} : vector<txf32> from vector<sxf32>
889+ //
890+ // // N vector.to_elements of size t vectors.
891+ // %e0:t = vector.to_elements %v0 : vector<txf32>
892+ // %e1:t = vector.to_elements %v1 : vector<txf32>
893+ // ...
894+ // %eNminus1:t = vector.to_elements %vNminus1 : vector<txf32>
895+ // ```
896+ SmallVector<Value> subVectors;
897+ SmallVector<int64_t > strides (targetShape->size (), 1 );
898+ for (int64_t i = 0 ; i < N; i++) {
899+ SmallVector<int64_t > elementOffsets = {i * target_rank};
900+ Value subVector = rewriter.createOrFold <vector::ExtractStridedSliceOp>(
901+ op.getLoc (), op.getSource (), elementOffsets, *targetShape, strides);
902+ subVectors.push_back (subVector);
903+ }
904+
905+ SmallVector<Value> elements;
906+ for (const Value subVector : subVectors) {
907+ auto elementsOp =
908+ vector::ToElementsOp::create (rewriter, op.getLoc (), subVector);
909+ llvm::append_range (elements, elementsOp.getResults ());
910+ }
911+
912+ rewriter.replaceOp (op, elements);
913+ return success ();
914+ }
915+
916+ private:
917+ vector::UnrollVectorOptions options;
918+ };
919+
837920// / Unrolls 2 or more dimensional `vector.to_elements` ops by unrolling the
838921// / outermost dimension of the operand. For example:
839922// /
840923// / ```
841- // / %0:4 = vector.to_elements %v : vector<2x2xf32 >
924+ // / %0:8 = vector.to_elements %v : vector<2x2x2xf32 >
842925// /
843926// / ==>
844927// /
@@ -865,6 +948,7 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
865948 FailureOr<SmallVector<Value>> result =
866949 vector::unrollVectorValue (source, rewriter);
867950 if (failed (result)) {
951+ // Only fails if operand is 1-dimensional.
868952 return failure ();
869953 }
870954 SmallVector<Value> vectors = *result;
@@ -1013,8 +1097,8 @@ void mlir::vector::populateVectorUnrollPatterns(
10131097 UnrollReductionPattern, UnrollMultiReductionPattern,
10141098 UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
10151099 UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
1016- UnrollToElements, UnrollStepPattern>(patterns. getContext (),
1017- options, benefit);
1100+ UnrollToElements, UnrollStepPattern, ToElementsToTargetShape>(
1101+ patterns. getContext (), options, benefit);
10181102}
10191103
10201104void mlir::vector::populateVectorToElementsUnrollPatterns (
0 commit comments