@@ -221,7 +221,7 @@ namespace {
221221// / ```
222222// / `vector.matmul` later lowers to `llvm.matrix.multiply`.
223223//
224- // / This only kicks in when VectorTransformsOptions is set to OuterProduct and
224+ // / This only kicks in when VectorTransformsOptions is set to Matmul and
225225// / the vector.contract op is a row-major matrix multiply.
226226class ContractionOpToMatmulOpLowering
227227 : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
@@ -236,11 +236,11 @@ class ContractionOpToMatmulOpLowering
236236 }
237237
238238 ContractionOpToMatmulOpLowering (
239- vector::VectorTransformsOptions vectorTransformOptions ,
239+ vector::VectorContractLowering vectorContractLowering ,
240240 MLIRContext *context, PatternBenefit benefit = 1 ,
241241 FilterConstraintType constraint = defaultFilter)
242242 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
243- vectorTransformOptions (vectorTransformOptions ),
243+ vectorContractLowering (vectorContractLowering ),
244244 filter (std::move(constraint)) {}
245245
246246 FailureOr<Value>
@@ -249,7 +249,7 @@ class ContractionOpToMatmulOpLowering
249249
250250private:
251251 // / Options to control the vector patterns.
252- vector::VectorTransformsOptions vectorTransformOptions ;
252+ vector::VectorContractLowering vectorContractLowering ;
253253 FilterConstraintType filter;
254254};
255255
@@ -281,11 +281,11 @@ class ContractionOpToOuterProductOpLowering
281281 }
282282
283283 ContractionOpToOuterProductOpLowering (
284- vector::VectorTransformsOptions vectorTransformOptions ,
284+ vector::VectorContractLowering vectorContractLowering ,
285285 MLIRContext *context, PatternBenefit benefit = 1 ,
286286 FilterConstraintType constraint = defaultFilter)
287287 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
288- vectorTransformOptions (vectorTransformOptions ),
288+ vectorContractLowering (vectorContractLowering ),
289289 filter (std::move(constraint)) {}
290290
291291 FailureOr<Value>
@@ -294,7 +294,7 @@ class ContractionOpToOuterProductOpLowering
294294
295295private:
296296 // / Options to control the vector patterns.
297- vector::VectorTransformsOptions vectorTransformOptions ;
297+ vector::VectorContractLowering vectorContractLowering ;
298298 FilterConstraintType filter;
299299};
300300
@@ -329,19 +329,19 @@ class ContractionOpToDotLowering
329329 }
330330
331331 ContractionOpToDotLowering (
332- vector::VectorTransformsOptions vectorTransformOptions ,
332+ vector::VectorContractLowering vectorContractLowering ,
333333 MLIRContext *context, PatternBenefit benefit = 1 ,
334334 const FilterConstraintType &constraint = defaultFilter)
335335 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
336- vectorTransformOptions (vectorTransformOptions ), filter(defaultFilter) {}
336+ vectorContractLowering (vectorContractLowering ), filter(defaultFilter) {}
337337
338338 FailureOr<Value>
339339 matchAndRewriteMaskableOp (vector::ContractionOp op, MaskingOpInterface maskOp,
340340 PatternRewriter &rewriter) const override ;
341341
342342private:
343343 // / Options to control the vector patterns.
344- vector::VectorTransformsOptions vectorTransformOptions ;
344+ vector::VectorContractLowering vectorContractLowering ;
345345 FilterConstraintType filter;
346346};
347347
@@ -370,11 +370,12 @@ class ContractionOpLowering
370370 return success ();
371371 }
372372
373- ContractionOpLowering (vector::VectorTransformsOptions vectorTransformOptions,
374- MLIRContext *context, PatternBenefit benefit = 1 ,
375- FilterConstraintType constraint = defaultFilter)
373+ ContractionOpLowering (
374+ vector::VectorContractLowering vectorContractLoweringOption,
375+ MLIRContext *context, PatternBenefit benefit = 1 ,
376+ FilterConstraintType constraint = defaultFilter)
376377 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
377- vectorTransformOptions (vectorTransformOptions ),
378+ vectorContractLoweringOption (vectorContractLoweringOption ),
378379 filter (std::move(constraint)) {}
379380
380381 FailureOr<Value>
@@ -383,7 +384,7 @@ class ContractionOpLowering
383384
384385private:
385386 // / Options to control the vector patterns.
386- vector::VectorTransformsOptions vectorTransformOptions ;
387+ vector::VectorContractLowering vectorContractLoweringOption ;
387388 FilterConstraintType filter;
388389 // Lower one parallel dimension.
389390 FailureOr<Value> lowerParallel (PatternRewriter &rewriter,
@@ -641,8 +642,7 @@ FailureOr<Value>
641642ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp (
642643 vector::ContractionOp op, MaskingOpInterface maskOp,
643644 PatternRewriter &rewriter) const {
644- if (vectorTransformOptions.vectorContractLowering !=
645- vector::VectorContractLowering::OuterProduct)
645+ if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
646646 return failure ();
647647
648648 if (failed (filter (op)))
@@ -672,8 +672,7 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
672672 if (failed (filter (op)))
673673 return failure ();
674674
675- if (vectorTransformOptions.vectorContractLowering !=
676- vector::VectorContractLowering::Dot)
675+ if (vectorContractLowering != vector::VectorContractLowering::Dot)
677676 return failure ();
678677
679678 auto iteratorTypes = op.getIteratorTypes ().getValue ();
@@ -789,11 +788,11 @@ struct ContractOpToElementwise
789788 return success ();
790789 }
791790 ContractOpToElementwise (
792- vector::VectorTransformsOptions vectorTransformOptions ,
791+ vector::VectorContractLowering vectorContractLowering ,
793792 MLIRContext *context, PatternBenefit benefit = 1 ,
794793 const FilterConstraintType &constraint = defaultFilter)
795794 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
796- vectorTransformOptions (vectorTransformOptions ), filter(defaultFilter) {}
795+ vectorContractLowering (vectorContractLowering ), filter(defaultFilter) {}
797796
798797 FailureOr<Value>
799798 matchAndRewriteMaskableOp (vector::ContractionOp contractOp,
@@ -806,8 +805,7 @@ struct ContractOpToElementwise
806805 if (failed (filter (contractOp)))
807806 return failure ();
808807
809- if (vectorTransformOptions.vectorContractLowering !=
810- vector::VectorContractLowering::ParallelArith)
808+ if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
811809 return failure ();
812810
813811 ArrayRef<int64_t > lhsShape = contractOp.getLhsType ().getShape ();
@@ -898,7 +896,7 @@ struct ContractOpToElementwise
898896
899897private:
900898 // / Options to control the vector patterns.
901- vector::VectorTransformsOptions vectorTransformOptions ;
899+ vector::VectorContractLowering vectorContractLowering ;
902900 FilterConstraintType filter;
903901};
904902
@@ -941,25 +939,25 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
941939 // TODO: implement benefits, cost models.
942940 MLIRContext *ctx = op.getContext ();
943941
944- ContractionOpToMatmulOpLowering pat1 (vectorTransformOptions , ctx);
942+ ContractionOpToMatmulOpLowering pat1 (vectorContractLoweringOption , ctx);
945943 FailureOr<Value> newVal1 =
946944 pat1.matchAndRewriteMaskableOp (op, maskOp, rewriter);
947945 if (!failed (newVal1))
948946 return newVal1;
949947
950- ContractionOpToOuterProductOpLowering pat2 (vectorTransformOptions , ctx);
948+ ContractionOpToOuterProductOpLowering pat2 (vectorContractLoweringOption , ctx);
951949 FailureOr<Value> newVal2 =
952950 pat2.matchAndRewriteMaskableOp (op, maskOp, rewriter);
953951 if (!failed (newVal2))
954952 return newVal2;
955953
956- ContractionOpToDotLowering pat3 (vectorTransformOptions , ctx);
954+ ContractionOpToDotLowering pat3 (vectorContractLoweringOption , ctx);
957955 FailureOr<Value> newVal3 =
958956 pat3.matchAndRewriteMaskableOp (op, maskOp, rewriter);
959957 if (!failed (newVal3))
960958 return newVal3;
961959
962- ContractOpToElementwise pat4 (vectorTransformOptions , ctx);
960+ ContractOpToElementwise pat4 (vectorContractLoweringOption , ctx);
963961 FailureOr<Value> newVal4 =
964962 pat4.matchAndRewriteMaskableOp (op, maskOp, rewriter);
965963 if (!failed (newVal4))
@@ -1292,8 +1290,7 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
12921290 if (maskOp)
12931291 return failure ();
12941292
1295- if (vectorTransformOptions.vectorContractLowering !=
1296- vector::VectorContractLowering::Matmul)
1293+ if (vectorContractLowering != vector::VectorContractLowering::Matmul)
12971294 return failure ();
12981295 if (failed (filter (op)))
12991296 return failure ();
@@ -1382,13 +1379,14 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
13821379} // namespace
13831380
13841381void mlir::vector::populateVectorContractLoweringPatterns (
1385- RewritePatternSet &patterns, VectorTransformsOptions options,
1386- PatternBenefit benefit, bool disableOuterProductLowering) {
1382+ RewritePatternSet &patterns,
1383+ VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
1384+ bool disableOuterProductLowering) {
13871385 if (!disableOuterProductLowering)
13881386 patterns.add <OuterProductOpLowering>(patterns.getContext (), benefit);
13891387 patterns.add <ContractionOpLowering, ContractionOpToMatmulOpLowering,
13901388 ContractionOpToOuterProductOpLowering>(
1391- options , patterns.getContext (), benefit);
1389+ vectorContractLoweringOption , patterns.getContext (), benefit);
13921390}
13931391
13941392void mlir::vector::populateVectorOuterProductLoweringPatterns (
0 commit comments