|
13 | 13 | //===----------------------------------------------------------------------===//
|
14 | 14 |
|
15 | 15 | #include "mlir/Dialect/Vector/IR/VectorOps.h"
|
| 16 | +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" |
| 17 | +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" |
16 | 18 | #include "mlir/IR/BuiltinOps.h"
|
17 | 19 | #include "mlir/IR/BuiltinTypes.h"
|
| 20 | +#include "mlir/Pass/Pass.h" |
18 | 21 | #include "mlir/Support/LogicalResult.h"
|
| 22 | +#include "mlir/Transforms/DialectConversion.h" |
19 | 23 | #include "llvm/ADT/ArrayRef.h"
|
20 | 24 | #include "llvm/Transforms/Utils/AddDiscriminators.h"
|
21 |
| -#include <cstdint> |
22 |
| -#include <imex/Transforms/Passes.h> |
23 | 25 |
|
24 |
| -#include <mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h> |
25 |
| -#include <mlir/Pass/Pass.h> |
26 |
| -#include <mlir/Transforms/DialectConversion.h> |
| 26 | +#include "imex/Transforms/Passes.h" |
| 27 | + |
| 28 | +#include <cstdint> |
27 | 29 | #include <numeric>
|
28 | 30 |
|
29 | 31 | namespace imex {
|
@@ -227,15 +229,23 @@ struct VectorLinearizePass final
|
227 | 229 | mlir::RewritePatternSet patterns(context);
|
228 | 230 | mlir::ConversionTarget target(*context);
|
229 | 231 |
|
| 232 | + typeConverter.addConversion([](mlir::Type type) { return type; }); |
| 233 | + |
230 | 234 | target.addDynamicallyLegalOp<mlir::vector::ShuffleOp>([&](mlir::Operation
|
231 | 235 | *op) {
|
232 | 236 | return op->getResult(0).getType().cast<mlir::VectorType>().getRank() == 1;
|
233 | 237 | });
|
234 | 238 |
|
| 239 | + target.addIllegalOp<mlir::vector::TransposeOp>(); |
| 240 | + target.addLegalOp<mlir::vector::ShapeCastOp>(); |
| 241 | + |
235 | 242 | patterns.add<VectorExtractStridedSliceConversion, VectorShffleOpConversion,
|
236 | 243 | VectorExtractOpConversion>(typeConverter, context);
|
237 | 244 |
|
238 |
| - typeConverter.addConversion([](mlir::Type type) { return type; }); |
| 245 | + mlir::vector::populateVectorTransposeLoweringPatterns( |
| 246 | + patterns, |
| 247 | + mlir::vector::VectorTransformsOptions().setVectorTransposeLowering( |
| 248 | + mlir::vector::VectorTransposeLowering::Shuffle1D)); |
239 | 249 | mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
|
240 | 250 | typeConverter, patterns, target);
|
241 | 251 | if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
|
|
0 commit comments