@@ -215,13 +215,13 @@ namespace {
215215// / ```
216216// / %flattened_a = vector.shape_cast %a
217217// / %flattened_b = vector.shape_cast %b
218- // / %flattened_d = vector.matmul %flattened_a, %flattened_b
218+ // / %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
219219// / %d = vector.shape_cast %%flattened_d
220220// / %e = add %c, %d
221221// / ```
222- // / `vector.matmul ` later lowers to `llvm.matrix.multiply`.
222+ // / `vector.matrix_multiply ` later lowers to `llvm.matrix.multiply`.
223223//
224- // / This only kicks in when VectorTransformsOptions is set to OuterProduct and
224+ // / This only kicks in when vectorContractLowering is set to Matmul and
225225// / the vector.contract op is a row-major matrix multiply.
226226class ContractionOpToMatmulOpLowering
227227 : public vector::MaskableOpRewritePattern<vector::ContractionOp> {
@@ -236,11 +236,11 @@ class ContractionOpToMatmulOpLowering
236236 }
237237
238238 ContractionOpToMatmulOpLowering (
239- vector::VectorTransformsOptions vectorTransformOptions ,
239+ vector::VectorContractLowering vectorContractLowering ,
240240 MLIRContext *context, PatternBenefit benefit = 1 ,
241241 FilterConstraintType constraint = defaultFilter)
242242 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
243- vectorTransformOptions (vectorTransformOptions ),
243+ vectorContractLowering (vectorContractLowering ),
244244 filter (std::move(constraint)) {}
245245
246246 FailureOr<Value>
@@ -249,7 +249,7 @@ class ContractionOpToMatmulOpLowering
249249
250250private:
251251 // / Options to control the vector patterns.
252- vector::VectorTransformsOptions vectorTransformOptions ;
252+ vector::VectorContractLowering vectorContractLowering ;
253253 FilterConstraintType filter;
254254};
255255
@@ -266,7 +266,7 @@ class ContractionOpToMatmulOpLowering
266266// / %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
267267// / ```
268268// /
269- // / This only kicks in when VectorTransformsOptions is set to OuterProduct and
269+ // / This only kicks in when vectorContractLowering is set to OuterProduct and
270270// / the vector.contract op is a row-major matrix multiply.
271271class ContractionOpToOuterProductOpLowering
272272 : public MaskableOpRewritePattern<vector::ContractionOp> {
@@ -281,11 +281,11 @@ class ContractionOpToOuterProductOpLowering
281281 }
282282
283283 ContractionOpToOuterProductOpLowering (
284- vector::VectorTransformsOptions vectorTransformOptions ,
284+ vector::VectorContractLowering vectorContractLowering ,
285285 MLIRContext *context, PatternBenefit benefit = 1 ,
286286 FilterConstraintType constraint = defaultFilter)
287287 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
288- vectorTransformOptions (vectorTransformOptions ),
288+ vectorContractLowering (vectorContractLowering ),
289289 filter (std::move(constraint)) {}
290290
291291 FailureOr<Value>
@@ -294,7 +294,7 @@ class ContractionOpToOuterProductOpLowering
294294
295295private:
296296 // / Options to control the vector patterns.
297- vector::VectorTransformsOptions vectorTransformOptions ;
297+ vector::VectorContractLowering vectorContractLowering ;
298298 FilterConstraintType filter;
299299};
300300
@@ -329,19 +329,19 @@ class ContractionOpToDotLowering
329329 }
330330
331331 ContractionOpToDotLowering (
332- vector::VectorTransformsOptions vectorTransformOptions ,
332+ vector::VectorContractLowering vectorContractLowering ,
333333 MLIRContext *context, PatternBenefit benefit = 1 ,
334334 const FilterConstraintType &constraint = defaultFilter)
335335 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
336- vectorTransformOptions (vectorTransformOptions ), filter(defaultFilter) {}
336+ vectorContractLowering (vectorContractLowering ), filter(defaultFilter) {}
337337
338338 FailureOr<Value>
339339 matchAndRewriteMaskableOp (vector::ContractionOp op, MaskingOpInterface maskOp,
340340 PatternRewriter &rewriter) const override ;
341341
342342private:
343343 // / Options to control the vector patterns.
344- vector::VectorTransformsOptions vectorTransformOptions ;
344+ vector::VectorContractLowering vectorContractLowering ;
345345 FilterConstraintType filter;
346346};
347347
@@ -370,11 +370,12 @@ class ContractionOpLowering
370370 return success ();
371371 }
372372
373- ContractionOpLowering (vector::VectorTransformsOptions vectorTransformOptions,
374- MLIRContext *context, PatternBenefit benefit = 1 ,
375- FilterConstraintType constraint = defaultFilter)
373+ ContractionOpLowering (
374+ vector::VectorContractLowering vectorContractLoweringOption,
375+ MLIRContext *context, PatternBenefit benefit = 1 ,
376+ FilterConstraintType constraint = defaultFilter)
376377 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
377- vectorTransformOptions (vectorTransformOptions ),
378+ vectorContractLoweringOption (vectorContractLoweringOption ),
378379 filter (std::move(constraint)) {}
379380
380381 FailureOr<Value>
@@ -383,7 +384,7 @@ class ContractionOpLowering
383384
384385private:
385386 // / Options to control the vector patterns.
386- vector::VectorTransformsOptions vectorTransformOptions ;
387+ vector::VectorContractLowering vectorContractLoweringOption ;
387388 FilterConstraintType filter;
388389 // Lower one parallel dimension.
389390 FailureOr<Value> lowerParallel (PatternRewriter &rewriter,
@@ -635,14 +636,13 @@ struct UnrolledOuterProductGenerator
635636// / %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
636637// / ```
637638// /
638- // / This only kicks in when VectorTransformsOptions is set to OuterProduct but
639+ // / This only kicks in when vectorContractLowering is set to OuterProduct but
639640// / otherwise supports any layout permutation of the matrix-multiply.
640641FailureOr<Value>
641642ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp (
642643 vector::ContractionOp op, MaskingOpInterface maskOp,
643644 PatternRewriter &rewriter) const {
644- if (vectorTransformOptions.vectorContractLowering !=
645- vector::VectorContractLowering::OuterProduct)
645+ if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
646646 return failure ();
647647
648648 if (failed (filter (op)))
@@ -672,8 +672,7 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
672672 if (failed (filter (op)))
673673 return failure ();
674674
675- if (vectorTransformOptions.vectorContractLowering !=
676- vector::VectorContractLowering::Dot)
675+ if (vectorContractLowering != vector::VectorContractLowering::Dot)
677676 return failure ();
678677
679678 auto iteratorTypes = op.getIteratorTypes ().getValue ();
@@ -789,11 +788,11 @@ struct ContractOpToElementwise
789788 return success ();
790789 }
791790 ContractOpToElementwise (
792- vector::VectorTransformsOptions vectorTransformOptions ,
791+ vector::VectorContractLowering vectorContractLowering ,
793792 MLIRContext *context, PatternBenefit benefit = 1 ,
794793 const FilterConstraintType &constraint = defaultFilter)
795794 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
796- vectorTransformOptions (vectorTransformOptions ), filter(defaultFilter) {}
795+ vectorContractLowering (vectorContractLowering ), filter(defaultFilter) {}
797796
798797 FailureOr<Value>
799798 matchAndRewriteMaskableOp (vector::ContractionOp contractOp,
@@ -806,8 +805,7 @@ struct ContractOpToElementwise
806805 if (failed (filter (contractOp)))
807806 return failure ();
808807
809- if (vectorTransformOptions.vectorContractLowering !=
810- vector::VectorContractLowering::ParallelArith)
808+ if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
811809 return failure ();
812810
813811 ArrayRef<int64_t > lhsShape = contractOp.getLhsType ().getShape ();
@@ -898,7 +896,7 @@ struct ContractOpToElementwise
898896
899897private:
900898 // / Options to control the vector patterns.
901- vector::VectorTransformsOptions vectorTransformOptions ;
899+ vector::VectorContractLowering vectorContractLowering ;
902900 FilterConstraintType filter;
903901};
904902
@@ -913,7 +911,7 @@ struct ContractOpToElementwise
913911// / until a pure contraction is reached (no free/batch dimensions),
914912// / which is replaced by a dot-product.
915913// /
916- // / This only kicks in when either VectorTransformsOptions is set
914+ // / This only kicks in when either vectorContractLoweringOption is set
917915// / to DOT or when other contraction patterns fail.
918916//
919917// TODO: break down into transpose/reshape/cast ops
@@ -941,25 +939,25 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
941939 // TODO: implement benefits, cost models.
942940 MLIRContext *ctx = op.getContext ();
943941
944- ContractionOpToMatmulOpLowering pat1 (vectorTransformOptions , ctx);
942+ ContractionOpToMatmulOpLowering pat1 (vectorContractLoweringOption , ctx);
945943 FailureOr<Value> newVal1 =
946944 pat1.matchAndRewriteMaskableOp (op, maskOp, rewriter);
947945 if (!failed (newVal1))
948946 return newVal1;
949947
950- ContractionOpToOuterProductOpLowering pat2 (vectorTransformOptions , ctx);
948+ ContractionOpToOuterProductOpLowering pat2 (vectorContractLoweringOption , ctx);
951949 FailureOr<Value> newVal2 =
952950 pat2.matchAndRewriteMaskableOp (op, maskOp, rewriter);
953951 if (!failed (newVal2))
954952 return newVal2;
955953
956- ContractionOpToDotLowering pat3 (vectorTransformOptions , ctx);
954+ ContractionOpToDotLowering pat3 (vectorContractLoweringOption , ctx);
957955 FailureOr<Value> newVal3 =
958956 pat3.matchAndRewriteMaskableOp (op, maskOp, rewriter);
959957 if (!failed (newVal3))
960958 return newVal3;
961959
962- ContractOpToElementwise pat4 (vectorTransformOptions , ctx);
960+ ContractOpToElementwise pat4 (vectorContractLoweringOption , ctx);
963961 FailureOr<Value> newVal4 =
964962 pat4.matchAndRewriteMaskableOp (op, maskOp, rewriter);
965963 if (!failed (newVal4))
@@ -1273,14 +1271,14 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
12731271// / %mtb = maybe_transpose
12741272// / %flattened_a = vector.shape_cast %mta
12751273// / %flattened_b = vector.shape_cast %mtb
1276- // / %flattened_d = vector.matmul %flattened_a, %flattened_b
1274+ // / %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b
12771275// / %mtd = vector.shape_cast %flattened_d
12781276// / %d = maybe_untranspose %mtd
12791277// / %e = add %c, %d
12801278// / ```
1281- // / `vector.matmul ` later lowers to `llvm.matrix.multiply`.
1279+ // / `vector.matrix_multiply ` later lowers to `llvm.matrix.multiply`.
12821280//
1283- // / This only kicks in when VectorTransformsOptions is set to `Matmul`.
1281+ // / This only kicks in when vectorContractLowering is set to `Matmul`.
12841282// / vector.transpose operations are inserted if the vector.contract op is not a
12851283// / row-major matrix multiply.
12861284// /
@@ -1292,8 +1290,7 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
12921290 if (maskOp)
12931291 return failure ();
12941292
1295- if (vectorTransformOptions.vectorContractLowering !=
1296- vector::VectorContractLowering::Matmul)
1293+ if (vectorContractLowering != vector::VectorContractLowering::Matmul)
12971294 return failure ();
12981295 if (failed (filter (op)))
12991296 return failure ();
@@ -1382,13 +1379,14 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
13821379} // namespace
13831380
13841381void mlir::vector::populateVectorContractLoweringPatterns (
1385- RewritePatternSet &patterns, VectorTransformsOptions options,
1386- PatternBenefit benefit, bool disableOuterProductLowering) {
1382+ RewritePatternSet &patterns,
1383+ VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
1384+ bool disableOuterProductLowering) {
13871385 if (!disableOuterProductLowering)
13881386 patterns.add <OuterProductOpLowering>(patterns.getContext (), benefit);
13891387 patterns.add <ContractionOpLowering, ContractionOpToMatmulOpLowering,
13901388 ContractionOpToOuterProductOpLowering>(
1391- options , patterns.getContext (), benefit);
1389+ vectorContractLoweringOption , patterns.getContext (), benefit);
13921390}
13931391
13941392void mlir::vector::populateVectorOuterProductLoweringPatterns (
0 commit comments