Skip to content

Commit 3da78dc

Browse files
authored
Linearize vector.transpose to vector.shuffle (#714)
Upstream vector dialect already have transform which lowers transpose to linearized shuffle. Add it to our linearization pipeline.
1 parent f87f6f7 commit 3da78dc

File tree

2 files changed

+31
-7
lines changed

2 files changed

+31
-7
lines changed

lib/Transforms/VectorLinearize.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,19 @@
1313
//===----------------------------------------------------------------------===//
1414

1515
#include "mlir/Dialect/Vector/IR/VectorOps.h"
16+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
17+
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
1618
#include "mlir/IR/BuiltinOps.h"
1719
#include "mlir/IR/BuiltinTypes.h"
20+
#include "mlir/Pass/Pass.h"
1821
#include "mlir/Support/LogicalResult.h"
22+
#include "mlir/Transforms/DialectConversion.h"
1923
#include "llvm/ADT/ArrayRef.h"
2024
#include "llvm/Transforms/Utils/AddDiscriminators.h"
21-
#include <cstdint>
22-
#include <imex/Transforms/Passes.h>
2325

24-
#include <mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h>
25-
#include <mlir/Pass/Pass.h>
26-
#include <mlir/Transforms/DialectConversion.h>
26+
#include "imex/Transforms/Passes.h"
27+
28+
#include <cstdint>
2729
#include <numeric>
2830

2931
namespace imex {
@@ -227,15 +229,23 @@ struct VectorLinearizePass final
227229
mlir::RewritePatternSet patterns(context);
228230
mlir::ConversionTarget target(*context);
229231

232+
typeConverter.addConversion([](mlir::Type type) { return type; });
233+
230234
target.addDynamicallyLegalOp<mlir::vector::ShuffleOp>([&](mlir::Operation
231235
*op) {
232236
return op->getResult(0).getType().cast<mlir::VectorType>().getRank() == 1;
233237
});
234238

239+
target.addIllegalOp<mlir::vector::TransposeOp>();
240+
target.addLegalOp<mlir::vector::ShapeCastOp>();
241+
235242
patterns.add<VectorExtractStridedSliceConversion, VectorShffleOpConversion,
236243
VectorExtractOpConversion>(typeConverter, context);
237244

238-
typeConverter.addConversion([](mlir::Type type) { return type; });
245+
mlir::vector::populateVectorTransposeLoweringPatterns(
246+
patterns,
247+
mlir::vector::VectorTransformsOptions().setVectorTransposeLowering(
248+
mlir::vector::VectorTransposeLowering::Shuffle1D));
239249
mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
240250
typeConverter, patterns, target);
241251
if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,

test/Transforms/vector-linearize.mlir

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,4 +97,18 @@ func.func @test_vector_shuffle(%arg0: vector<4x4xf32>, %arg1: vector<4x4xf32>) -
9797
func.func @test_vector_extract(%arg0: vector<2x8x4xf32>) -> vector<8x4xf32> {
9898
%0 = vector.extract %arg0[1]: vector<8x4xf32> from vector<2x8x4xf32>
9999
return %0 : vector<8x4xf32>
100-
}
100+
}
101+
102+
// -----
103+
104+
// CHECK-LABEL: test_vector_transpose
105+
// CHECK-SAME: (%[[ORIG_ARG:.*]]: vector<2x8xf32>) -> vector<8x2xf32>
106+
// CHECK: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x8xf32> to vector<16xf32>
107+
// CHECK: %[[SHUFFLE:.*]] = vector.shuffle %[[ARG]], %[[ARG]]
108+
// CHECK: [0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] : vector<16xf32>, vector<16xf32>
109+
// CHECK: %[[RES:.*]] = vector.shape_cast %[[SHUFFLE]] : vector<16xf32> to vector<8x2xf32>
110+
// CHECK: return %[[RES]] : vector<8x2xf32>
111+
func.func @test_vector_transpose(%arg: vector<2x8xf32>) -> vector<8x2xf32> {
112+
%0 = vector.transpose %arg, [1, 0] : vector<2x8xf32> to vector<8x2xf32>
113+
return %0 : vector<8x2xf32>
114+
}

0 commit comments

Comments
 (0)