@@ -551,9 +551,10 @@ enum class Conv1DOpOrder {
551551 Nwc // Corresponds to operation that traverses the input in (n, w, c) order.
552552};
553553
554- // / Helper data structure to represent the result of vectorization.
555- // / In certain specific cases, like terminators, we do not want to propagate/
556- enum VectorizationStatus {
554+ // / Helper data structure to represent the result of vectorization for a single
555+ // / operation. In certain specific cases, like terminators, we do not want to
556+ // / propagate.
557+ enum VectorizationHookStatus {
557558 // / Op failed to vectorize.
558559 Failure = 0 ,
559560 // / Op vectorized and custom function took care of replacement logic
@@ -564,9 +565,12 @@ enum VectorizationStatus {
564565 // TODO: support values if Op vectorized to Many-Ops whose results we need to
565566 // aggregate for replacement.
566567};
567- struct VectorizationResult {
568+ // / VectorizationHookResult contains the vectorized op returned from a
569+ // / CustomVectorizationHook. This is an internal implementation detail of
570+ // / linalg vectorization, not to be confused with VectorizationResult.
571+ struct VectorizationHookResult {
568572 // / Return status from vectorizing the current op.
569- enum VectorizationStatus status = VectorizationStatus ::Failure;
573+ enum VectorizationHookStatus status = VectorizationHookStatus ::Failure;
570574 // / New vectorized operation to replace the current op.
571575 // / Replacement behavior is specified by `status`.
572576 Operation *newOp;
@@ -728,22 +732,22 @@ using CustomVectorizationPrecondition =
728732// assuming all its vectorized operands are already in the IRMapping.
729733// Return nullptr if the Operation cannot be vectorized.
730734using CustomVectorizationHook =
731- std::function<VectorizationResult (Operation *, const IRMapping &)>;
735+ std::function<VectorizationHookResult (Operation *, const IRMapping &)>;
732736
733737// / Helper function to vectorize the terminator of a `linalgOp`. New result
734738// / vector values are appended to `newResults`. Return
735- // / VectorizationStatus ::NoReplace to signal the vectorization algorithm that it
736- // / should not try to map produced operations and instead return the results
737- // / using the `newResults` vector making them available to the vectorization
738- // / algorithm for RAUW. This function is meant to be used as a
739+ // / VectorizationHookStatus ::NoReplace to signal the vectorization algorithm
740+ // / that it should not try to map produced operations and instead return the
741+ // / results using the `newResults` vector making them available to the
742+ // / vectorization algorithm for RAUW. This function is meant to be used as a
739743// / CustomVectorizationHook.
740- static VectorizationResult
744+ static VectorizationHookResult
741745vectorizeLinalgYield (RewriterBase &rewriter, Operation *op,
742746 const IRMapping &bvm, VectorizationState &state,
743747 LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
744748 auto yieldOp = dyn_cast<linalg::YieldOp>(op);
745749 if (!yieldOp)
746- return VectorizationResult{VectorizationStatus ::Failure, nullptr };
750+ return VectorizationHookResult{VectorizationHookStatus ::Failure, nullptr };
747751 for (const auto &output : llvm::enumerate (yieldOp.getValues ())) {
748752 // TODO: Scan for an opportunity for reuse.
749753 // TODO: use a map.
@@ -755,20 +759,20 @@ vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
755759 newResults.push_back (newResult);
756760 }
757761
758- return VectorizationResult{VectorizationStatus ::NoReplace, nullptr };
762+ return VectorizationHookResult{VectorizationHookStatus ::NoReplace, nullptr };
759763}
760764
761765// / Helper function to vectorize the index operations of a `linalgOp`. Return
762- // / VectorizationStatus ::NewOp to signal the vectorization algorithm that it
766+ // / VectorizationHookStatus ::NewOp to signal the vectorization algorithm that it
763767// / should map the produced operations. This function is meant to be used as a
764768// / CustomVectorizationHook.
765- static VectorizationResult vectorizeLinalgIndex (RewriterBase &rewriter,
766- VectorizationState &state,
767- Operation *op,
768- LinalgOp linalgOp) {
769+ static VectorizationHookResult vectorizeLinalgIndex (RewriterBase &rewriter,
770+ VectorizationState &state,
771+ Operation *op,
772+ LinalgOp linalgOp) {
769773 IndexOp indexOp = dyn_cast<linalg::IndexOp>(op);
770774 if (!indexOp)
771- return VectorizationResult{VectorizationStatus ::Failure, nullptr };
775+ return VectorizationHookResult{VectorizationHookStatus ::Failure, nullptr };
772776 auto loc = indexOp.getLoc ();
773777 // Compute the static loop sizes of the index op.
774778 ArrayRef<int64_t > targetShape = state.getCanonicalVecShape ();
@@ -782,7 +786,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
782786 // dimension of the iteration space since the vectorization algorithm in this
783787 // case can handle the broadcast.
784788 if (dim == targetShape.size () - 1 )
785- return VectorizationResult{VectorizationStatus ::NewOp, indexSteps};
789+ return VectorizationHookResult{VectorizationHookStatus ::NewOp, indexSteps};
786790 // Otherwise permute the targetShape to move the index dimension last,
787791 // broadcast the one-dimensional index vector to the permuted shape, and
788792 // finally transpose the broadcasted index vector to undo the permutation.
@@ -800,7 +804,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
800804 std::swap (transposition.back (), transposition[dim]);
801805 auto transposeOp =
802806 rewriter.create <vector::TransposeOp>(loc, broadCastOp, transposition);
803- return VectorizationResult{VectorizationStatus ::NewOp, transposeOp};
807+ return VectorizationHookResult{VectorizationHookStatus ::NewOp, transposeOp};
804808}
805809
806810// / Helper function to check if the tensor.extract can be vectorized by the
@@ -1098,15 +1102,15 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
10981102}
10991103
11001104// / Helper function to vectorize the tensor.extract operations. Returns
1101- // / VectorizationStatus ::NewOp to signal the vectorization algorithm that it
1105+ // / VectorizationHookStatus ::NewOp to signal the vectorization algorithm that it
11021106// / should map the produced operations. This function is meant to be used as a
11031107// / CustomVectorizationHook.
1104- static VectorizationResult
1108+ static VectorizationHookResult
11051109vectorizeTensorExtract (RewriterBase &rewriter, VectorizationState &state,
11061110 Operation *op, LinalgOp linalgOp, const IRMapping &bvm) {
11071111 tensor::ExtractOp extractOp = dyn_cast<tensor::ExtractOp>(op);
11081112 if (!extractOp)
1109- return VectorizationResult{VectorizationStatus ::Failure, nullptr };
1113+ return VectorizationHookResult{VectorizationHookStatus ::Failure, nullptr };
11101114 auto loc = extractOp.getLoc ();
11111115
11121116 // Compute the static loop sizes of the extract op.
@@ -1138,7 +1142,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
11381142 gatherOp = state.maskOperation (rewriter, gatherOp, linalgOp);
11391143
11401144 LDBG (" Vectorised as gather load: " << extractOp << " \n " );
1141- return VectorizationResult{VectorizationStatus ::NewOp, gatherOp};
1145+ return VectorizationHookResult{VectorizationHookStatus ::NewOp, gatherOp};
11421146 }
11431147
11441148 // 2. Handle:
@@ -1202,7 +1206,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
12021206 mlir::vector::maskOperation (rewriter, transferReadOp, allTrue);
12031207
12041208 LDBG (" Vectorised as scalar broadcast load: " << extractOp << " \n " );
1205- return VectorizationResult{VectorizationStatus::NewOp, maskedReadOp};
1209+ return VectorizationHookResult{VectorizationHookStatus::NewOp,
1210+ maskedReadOp};
12061211 }
12071212
12081213 // 2b. Handle contiguous access.
@@ -1228,7 +1233,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
12281233 inBounds);
12291234
12301235 LDBG (" Vectorised as contiguous load: " << extractOp);
1231- return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
1236+ return VectorizationHookResult{VectorizationHookStatus::NewOp,
1237+ transferReadOp};
12321238}
12331239
12341240// / Emit reduction operations if the shapes of the value to reduce is different
@@ -1268,9 +1274,9 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
12681274// / This function assumes all operands of `op` have been vectorized and are in
12691275// / the `bvm` mapping. As a consequence, this function is meant to be called on
12701276// / a topologically-sorted list of ops.
1271- // / This function does not update `bvm` but returns a VectorizationStatus that
1272- // / instructs the caller what `bvm` update needs to occur.
1273- static VectorizationResult
1277+ // / This function does not update `bvm` but returns a VectorizationHookStatus
1278+ // / that instructs the caller what `bvm` update needs to occur.
1279+ static VectorizationHookResult
12741280vectorizeOneOp (RewriterBase &rewriter, VectorizationState &state,
12751281 LinalgOp linalgOp, Operation *op, const IRMapping &bvm,
12761282 ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
@@ -1279,8 +1285,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
12791285 // 1. Try to apply any CustomVectorizationHook.
12801286 if (!customVectorizationHooks.empty ()) {
12811287 for (auto &customFunc : customVectorizationHooks) {
1282- VectorizationResult result = customFunc (op, bvm);
1283- if (result.status == VectorizationStatus ::Failure)
1288+ VectorizationHookResult result = customFunc (op, bvm);
1289+ if (result.status == VectorizationHookStatus ::Failure)
12841290 continue ;
12851291 return result;
12861292 }
@@ -1289,11 +1295,12 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
12891295 // 2. Constant ops don't get vectorized but rather broadcasted at their users.
12901296 // Clone so that the constant is not confined to the linalgOp block .
12911297 if (isa<arith::ConstantOp, func::ConstantOp>(op))
1292- return VectorizationResult{VectorizationStatus::NewOp, rewriter.clone (*op)};
1298+ return VectorizationHookResult{VectorizationHookStatus::NewOp,
1299+ rewriter.clone (*op)};
12931300
12941301 // 3. Only ElementwiseMappable are allowed in the generic vectorization.
12951302 if (!OpTrait::hasElementwiseMappableTraits (op))
1296- return VectorizationResult{VectorizationStatus ::Failure, nullptr };
1303+ return VectorizationHookResult{VectorizationHookStatus ::Failure, nullptr };
12971304
12981305 // 4 . Check if the operation is a reduction.
12991306 SmallVector<std::pair<Value, Value>> reductionOperands;
@@ -1316,7 +1323,7 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
13161323 reduceIfNeeded (rewriter, linalgOp, op, reductionOperands[0 ].first ,
13171324 reductionOperands[0 ].second , bvm);
13181325 if (reduceOp)
1319- return VectorizationResult{VectorizationStatus ::NewOp, reduceOp};
1326+ return VectorizationHookResult{VectorizationHookStatus ::NewOp, reduceOp};
13201327 }
13211328
13221329 // 5. Generic vectorization path for ElementwiseMappable ops.
@@ -1356,8 +1363,8 @@ vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state,
13561363 : resultType);
13571364 }
13581365 // d. Build and return the new op.
1359- return VectorizationResult {
1360- VectorizationStatus ::NewOp,
1366+ return VectorizationHookResult {
1367+ VectorizationHookStatus ::NewOp,
13611368 rewriter.create (op->getLoc (), op->getName ().getIdentifier (), vecOperands,
13621369 resultTypes, op->getAttrs ())};
13631370}
@@ -1461,34 +1468,34 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
14611468 SmallVector<CustomVectorizationHook> hooks;
14621469 // 4a. Register CustomVectorizationHook for yieldOp.
14631470 CustomVectorizationHook vectorizeYield =
1464- [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1471+ [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
14651472 return vectorizeLinalgYield (rewriter, op, bvm, state, linalgOp, newResults);
14661473 };
14671474 hooks.push_back (vectorizeYield);
14681475
14691476 // 4b. Register CustomVectorizationHook for indexOp.
14701477 CustomVectorizationHook vectorizeIndex =
1471- [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1478+ [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
14721479 return vectorizeLinalgIndex (rewriter, state, op, linalgOp);
14731480 };
14741481 hooks.push_back (vectorizeIndex);
14751482
14761483 // 4c. Register CustomVectorizationHook for extractOp.
14771484 CustomVectorizationHook vectorizeExtract =
1478- [&](Operation *op, const IRMapping &bvm) -> VectorizationResult {
1485+ [&](Operation *op, const IRMapping &bvm) -> VectorizationHookResult {
14791486 return vectorizeTensorExtract (rewriter, state, op, linalgOp, bvm);
14801487 };
14811488 hooks.push_back (vectorizeExtract);
14821489
14831490 // 5. Iteratively call `vectorizeOneOp` to each op in the slice.
14841491 for (Operation &op : block->getOperations ()) {
1485- VectorizationResult result =
1492+ VectorizationHookResult result =
14861493 vectorizeOneOp (rewriter, state, linalgOp, &op, bvm, hooks);
1487- if (result.status == VectorizationStatus ::Failure) {
1494+ if (result.status == VectorizationHookStatus ::Failure) {
14881495 LDBG (" failed to vectorize: " << op << " \n " );
14891496 return failure ();
14901497 }
1491- if (result.status == VectorizationStatus ::NewOp) {
1498+ if (result.status == VectorizationHookStatus ::NewOp) {
14921499 Operation *maybeMaskedOp =
14931500 state.maskOperation (rewriter, result.newOp , linalgOp);
14941501 LDBG (" New vector op: " << *maybeMaskedOp << " \n " );
@@ -2525,17 +2532,11 @@ bool mlir::linalg::hasVectorizationImpl(Operation *op) {
25252532 tensor::InsertSliceOp>(op);
25262533}
25272534
2528- // / Emit a suitable vector form for an operation. If provided,
2529- // / `inputVectorSizes` are used to vectorize this operation.
2530- // / `inputVectorSizes` must match the rank of the iteration space of the
2531- // / operation and the input vector sizes must be greater than or equal to
2532- // / their counterpart iteration space sizes, if static. `inputVectorShapes`
2533- // / also allows the vectorization of operations with dynamic shapes.
2534- LogicalResult mlir::linalg::vectorize (RewriterBase &rewriter, Operation *op,
2535- ArrayRef<int64_t > inputVectorSizes,
2536- ArrayRef<bool > inputScalableVecDims,
2537- bool vectorizeNDExtract,
2538- bool flatten1DDepthwiseConv) {
2535+ FailureOr<VectorizationResult>
2536+ mlir::linalg::vectorize (RewriterBase &rewriter, Operation *op,
2537+ ArrayRef<int64_t > inputVectorSizes,
2538+ ArrayRef<bool > inputScalableVecDims,
2539+ bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
25392540 LDBG (" Attempting to vectorize:\n " << *op << " \n " );
25402541 LDBG (" Input vector sizes: " );
25412542 LLVM_DEBUG (llvm::interleaveComma (inputVectorSizes, llvm::dbgs ()));
@@ -2617,12 +2618,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
26172618 return failure ();
26182619 }
26192620
2620- if (!results.empty ())
2621- rewriter.replaceOp (op, results);
2622- else
2623- rewriter.eraseOp (op);
2624-
2625- return success ();
2621+ return VectorizationResult{results};
26262622}
26272623
26282624LogicalResult mlir::linalg::vectorizeCopy (RewriterBase &rewriter,
0 commit comments