Skip to content

Commit 6910060

Browse files
committed
Explicitly speciy all vector transform options on ConvertVectorToLLVMPass
Refactor ConvertVectorToLLVMPass options
1 parent a1163d8 commit 6910060

File tree

8 files changed

+101
-68
lines changed

8 files changed

+101
-68
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#define MLIR_CONVERSION_PASSES
1111

1212
include "mlir/Pass/PassBase.td"
13-
13+
include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td"
1414

1515
//===----------------------------------------------------------------------===//
1616
// ToLLVM
@@ -1410,10 +1410,32 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
14101410
"bool", /*default=*/"false",
14111411
"Enables the use of X86Vector dialect while lowering the vector "
14121412
"dialect.">,
1413-
Option<"vectorTransformsOptions", "vector-transform-options",
1414-
"vector::VectorTransformsOptions",
1415-
/*default=*/"vector::VectorTransformsOptions()",
1416-
"Options to lower some operations like contractions and transposes.">,
1413+
Option<"vectorContractLowering", "vector-contract-lowering",
1414+
"vector::VectorContractLowering",
1415+
/*default=*/"vector::VectorContractLowering::Dot",
1416+
VectorContractLoweringAttr.summary, [{::llvm::cl::values(
1417+
clEnumValN(::mlir::vector::VectorContractLowering::Dot, "dot",
1418+
"Progressively lower to finer grained `vector.contract` and dot-products. (default)"),
1419+
clEnumValN(::mlir::vector::VectorContractLowering::Matmul, "matmul",
1420+
"Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics."),
1421+
clEnumValN(::mlir::vector::VectorContractLowering::OuterProduct, "outerproduct",
1422+
"Lower to `vector.outerproduct`."),
1423+
clEnumValN(::mlir::vector::VectorContractLowering::ParallelArith, "parallelarith",
1424+
"Lower contract with all reduction dimensions unrolled to 1 to a vector elementwise operations.")
1425+
)}]>,
1426+
Option<"vectorTransposeLowering", "vector-transpose-lowering",
1427+
"vector::VectorTransposeLowering",
1428+
/*default=*/"vector::VectorTransposeLowering::EltWise",
1429+
VectorTransposeLoweringAttr.summary, [{::llvm::cl::values(
1430+
clEnumValN(::mlir::vector::VectorTransposeLowering::EltWise, "eltwise",
1431+
"Lower transpose into element-wise extract and inserts (default)"),
1432+
clEnumValN(::mlir::vector::VectorTransposeLowering::Flat, "flat",
1433+
"Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix intrinsics"),
1434+
clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle1D, "shuffle1d",
1435+
"Lower 2-D transpose to `vector.shuffle` on 1-D vector."),
1436+
clEnumValN(::mlir::vector::VectorTransposeLowering::Shuffle16x16, "shuffle16x16",
1437+
"Lower 2-D transpose to `vector.shuffle` on 16x16 vector.")
1438+
)}]>,
14171439
];
14181440
}
14191441

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
1010
#define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
1111

12+
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
1213
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1314

1415
namespace mlir {
@@ -47,7 +48,8 @@ namespace vector {
4748
/// Progressively lower a `vector.contract` with row-major matmul semantics to
4849
/// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`.
4950
void populateVectorContractLoweringPatterns(
50-
RewritePatternSet &patterns, VectorTransformsOptions options,
51+
RewritePatternSet &patterns,
52+
VectorContractLowering vectorContractLoweringOption,
5153
PatternBenefit benefit = 1, bool disableOuterProductLowering = false);
5254

5355
/// Populate the pattern set with the following patterns:
@@ -142,9 +144,10 @@ void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns,
142144
///
143145
/// [TransposeOp2DToShuffleLowering]
144146
///
145-
void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns,
146-
VectorTransformsOptions options,
147-
PatternBenefit benefit = 1);
147+
void populateVectorTransposeLoweringPatterns(
148+
RewritePatternSet &patterns,
149+
VectorTransposeLowering vectorTransposeLowering,
150+
PatternBenefit benefit = 1);
148151

149152
/// Populate the pattern set with the following patterns:
150153
///

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,11 @@ void ConvertVectorToLLVMPass::runOnOperation() {
6969
populateVectorToVectorCanonicalizationPatterns(patterns);
7070
populateVectorBitCastLoweringPatterns(patterns);
7171
populateVectorBroadcastLoweringPatterns(patterns);
72-
populateVectorContractLoweringPatterns(patterns, vectorTransformsOptions);
72+
populateVectorContractLoweringPatterns(patterns, vectorContractLowering);
7373
populateVectorMaskOpLoweringPatterns(patterns);
7474
populateVectorShapeCastLoweringPatterns(patterns);
7575
populateVectorInterleaveLoweringPatterns(patterns);
76-
populateVectorTransposeLoweringPatterns(patterns, vectorTransformsOptions);
76+
populateVectorTransposeLoweringPatterns(patterns, vectorTransposeLowering);
7777
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
7878
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
7979
populateVectorMaskMaterializationPatterns(patterns,

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1374,9 +1374,8 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
13741374
// further transformations to canonicalize/cancel.
13751375
{
13761376
RewritePatternSet patterns(context);
1377-
auto options = vector::VectorTransformsOptions().setVectorTransposeLowering(
1378-
vector::VectorTransposeLowering::EltWise);
1379-
vector::populateVectorTransposeLoweringPatterns(patterns, options);
1377+
vector::populateVectorTransposeLoweringPatterns(
1378+
patterns, vector::VectorTransposeLowering::EltWise);
13801379
vector::populateVectorShapeCastLoweringPatterns(patterns);
13811380
if (failed(applyPatternsGreedily(op, std::move(patterns))))
13821381
return failure();

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,7 @@ void transform::ApplyLowerBroadcastPatternsOp::populatePatterns(
102102

103103
void transform::ApplyLowerContractionPatternsOp::populatePatterns(
104104
RewritePatternSet &patterns) {
105-
vector::VectorTransformsOptions vectorTransformOptions;
106-
vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy());
107-
populateVectorContractLoweringPatterns(patterns, vectorTransformOptions,
105+
populateVectorContractLoweringPatterns(patterns, getLoweringStrategy(),
108106
/*benefit=*/1,
109107
/*disableOuterProductLowering=*/true);
110108
}
@@ -161,9 +159,8 @@ void transform::ApplyLowerTransferPatternsOp::populatePatterns(
161159

162160
void transform::ApplyLowerTransposePatternsOp::populatePatterns(
163161
RewritePatternSet &patterns) {
164-
vector::populateVectorTransposeLoweringPatterns(
165-
patterns, vector::VectorTransformsOptions().setVectorTransposeLowering(
166-
getLoweringStrategy()));
162+
vector::populateVectorTransposeLoweringPatterns(patterns,
163+
getLoweringStrategy());
167164
if (getAvx2LoweringStrategy()) {
168165
auto avx2LoweringOptions =
169166
x86vector::avx2::LoweringOptions().setTransposeOptions(

mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Lines changed: 31 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ namespace {
221221
/// ```
222222
/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
223223
//
224-
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
224+
/// This only kicks in when VectorTransformsOptions is set to Matmul and
225225
/// the vector.contract op is a row-major matrix multiply.
226226
class ContractionOpToMatmulOpLowering
227227
: public vector::MaskableOpRewritePattern<vector::ContractionOp> {
@@ -236,11 +236,11 @@ class ContractionOpToMatmulOpLowering
236236
}
237237

238238
ContractionOpToMatmulOpLowering(
239-
vector::VectorTransformsOptions vectorTransformOptions,
239+
vector::VectorContractLowering vectorContractLowering,
240240
MLIRContext *context, PatternBenefit benefit = 1,
241241
FilterConstraintType constraint = defaultFilter)
242242
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
243-
vectorTransformOptions(vectorTransformOptions),
243+
vectorContractLowering(vectorContractLowering),
244244
filter(std::move(constraint)) {}
245245

246246
FailureOr<Value>
@@ -249,7 +249,7 @@ class ContractionOpToMatmulOpLowering
249249

250250
private:
251251
/// Options to control the vector patterns.
252-
vector::VectorTransformsOptions vectorTransformOptions;
252+
vector::VectorContractLowering vectorContractLowering;
253253
FilterConstraintType filter;
254254
};
255255

@@ -281,11 +281,11 @@ class ContractionOpToOuterProductOpLowering
281281
}
282282

283283
ContractionOpToOuterProductOpLowering(
284-
vector::VectorTransformsOptions vectorTransformOptions,
284+
vector::VectorContractLowering vectorContractLowering,
285285
MLIRContext *context, PatternBenefit benefit = 1,
286286
FilterConstraintType constraint = defaultFilter)
287287
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
288-
vectorTransformOptions(vectorTransformOptions),
288+
vectorContractLowering(vectorContractLowering),
289289
filter(std::move(constraint)) {}
290290

291291
FailureOr<Value>
@@ -294,7 +294,7 @@ class ContractionOpToOuterProductOpLowering
294294

295295
private:
296296
/// Options to control the vector patterns.
297-
vector::VectorTransformsOptions vectorTransformOptions;
297+
vector::VectorContractLowering vectorContractLowering;
298298
FilterConstraintType filter;
299299
};
300300

@@ -329,19 +329,19 @@ class ContractionOpToDotLowering
329329
}
330330

331331
ContractionOpToDotLowering(
332-
vector::VectorTransformsOptions vectorTransformOptions,
332+
vector::VectorContractLowering vectorContractLowering,
333333
MLIRContext *context, PatternBenefit benefit = 1,
334334
const FilterConstraintType &constraint = defaultFilter)
335335
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
336-
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
336+
vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
337337

338338
FailureOr<Value>
339339
matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
340340
PatternRewriter &rewriter) const override;
341341

342342
private:
343343
/// Options to control the vector patterns.
344-
vector::VectorTransformsOptions vectorTransformOptions;
344+
vector::VectorContractLowering vectorContractLowering;
345345
FilterConstraintType filter;
346346
};
347347

@@ -370,11 +370,12 @@ class ContractionOpLowering
370370
return success();
371371
}
372372

373-
ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
374-
MLIRContext *context, PatternBenefit benefit = 1,
375-
FilterConstraintType constraint = defaultFilter)
373+
ContractionOpLowering(
374+
vector::VectorContractLowering vectorContractLoweringOption,
375+
MLIRContext *context, PatternBenefit benefit = 1,
376+
FilterConstraintType constraint = defaultFilter)
376377
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
377-
vectorTransformOptions(vectorTransformOptions),
378+
vectorContractLoweringOption(vectorContractLoweringOption),
378379
filter(std::move(constraint)) {}
379380

380381
FailureOr<Value>
@@ -383,7 +384,7 @@ class ContractionOpLowering
383384

384385
private:
385386
/// Options to control the vector patterns.
386-
vector::VectorTransformsOptions vectorTransformOptions;
387+
vector::VectorContractLowering vectorContractLoweringOption;
387388
FilterConstraintType filter;
388389
// Lower one parallel dimension.
389390
FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
@@ -641,8 +642,7 @@ FailureOr<Value>
641642
ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
642643
vector::ContractionOp op, MaskingOpInterface maskOp,
643644
PatternRewriter &rewriter) const {
644-
if (vectorTransformOptions.vectorContractLowering !=
645-
vector::VectorContractLowering::OuterProduct)
645+
if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
646646
return failure();
647647

648648
if (failed(filter(op)))
@@ -672,8 +672,7 @@ FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
672672
if (failed(filter(op)))
673673
return failure();
674674

675-
if (vectorTransformOptions.vectorContractLowering !=
676-
vector::VectorContractLowering::Dot)
675+
if (vectorContractLowering != vector::VectorContractLowering::Dot)
677676
return failure();
678677

679678
auto iteratorTypes = op.getIteratorTypes().getValue();
@@ -789,11 +788,11 @@ struct ContractOpToElementwise
789788
return success();
790789
}
791790
ContractOpToElementwise(
792-
vector::VectorTransformsOptions vectorTransformOptions,
791+
vector::VectorContractLowering vectorContractLowering,
793792
MLIRContext *context, PatternBenefit benefit = 1,
794793
const FilterConstraintType &constraint = defaultFilter)
795794
: MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
796-
vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {}
795+
vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
797796

798797
FailureOr<Value>
799798
matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
@@ -806,8 +805,7 @@ struct ContractOpToElementwise
806805
if (failed(filter(contractOp)))
807806
return failure();
808807

809-
if (vectorTransformOptions.vectorContractLowering !=
810-
vector::VectorContractLowering::ParallelArith)
808+
if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
811809
return failure();
812810

813811
ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
@@ -898,7 +896,7 @@ struct ContractOpToElementwise
898896

899897
private:
900898
/// Options to control the vector patterns.
901-
vector::VectorTransformsOptions vectorTransformOptions;
899+
vector::VectorContractLowering vectorContractLowering;
902900
FilterConstraintType filter;
903901
};
904902

@@ -941,25 +939,25 @@ FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
941939
// TODO: implement benefits, cost models.
942940
MLIRContext *ctx = op.getContext();
943941

944-
ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx);
942+
ContractionOpToMatmulOpLowering pat1(vectorContractLoweringOption, ctx);
945943
FailureOr<Value> newVal1 =
946944
pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
947945
if (!failed(newVal1))
948946
return newVal1;
949947

950-
ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx);
948+
ContractionOpToOuterProductOpLowering pat2(vectorContractLoweringOption, ctx);
951949
FailureOr<Value> newVal2 =
952950
pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
953951
if (!failed(newVal2))
954952
return newVal2;
955953

956-
ContractionOpToDotLowering pat3(vectorTransformOptions, ctx);
954+
ContractionOpToDotLowering pat3(vectorContractLoweringOption, ctx);
957955
FailureOr<Value> newVal3 =
958956
pat3.matchAndRewriteMaskableOp(op, maskOp, rewriter);
959957
if (!failed(newVal3))
960958
return newVal3;
961959

962-
ContractOpToElementwise pat4(vectorTransformOptions, ctx);
960+
ContractOpToElementwise pat4(vectorContractLoweringOption, ctx);
963961
FailureOr<Value> newVal4 =
964962
pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
965963
if (!failed(newVal4))
@@ -1292,8 +1290,7 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
12921290
if (maskOp)
12931291
return failure();
12941292

1295-
if (vectorTransformOptions.vectorContractLowering !=
1296-
vector::VectorContractLowering::Matmul)
1293+
if (vectorContractLowering != vector::VectorContractLowering::Matmul)
12971294
return failure();
12981295
if (failed(filter(op)))
12991296
return failure();
@@ -1382,13 +1379,14 @@ FailureOr<Value> ContractionOpToMatmulOpLowering::matchAndRewriteMaskableOp(
13821379
} // namespace
13831380

13841381
void mlir::vector::populateVectorContractLoweringPatterns(
1385-
RewritePatternSet &patterns, VectorTransformsOptions options,
1386-
PatternBenefit benefit, bool disableOuterProductLowering) {
1382+
RewritePatternSet &patterns,
1383+
VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
1384+
bool disableOuterProductLowering) {
13871385
if (!disableOuterProductLowering)
13881386
patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
13891387
patterns.add<ContractionOpLowering, ContractionOpToMatmulOpLowering,
13901388
ContractionOpToOuterProductOpLowering>(
1391-
options, patterns.getContext(), benefit);
1389+
vectorContractLoweringOption, patterns.getContext(), benefit);
13921390
}
13931391

13941392
void mlir::vector::populateVectorOuterProductLoweringPatterns(

0 commit comments

Comments
 (0)