diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 3885439e11f89..e68a3c77881fb 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2595,6 +2595,7 @@ def Vector_MaskOp : Vector_Op<"mask", [ def Vector_TransposeOp : Vector_Op<"transpose", [Pure, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, PredOpTrait<"operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>]> { diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 8789f55707267..25ce292f16e45 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6316,6 +6316,11 @@ std::optional> TransposeOp::getShapeForUnroll() { return llvm::to_vector<4>(getResultVectorType().getShape()); } +void TransposeOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRanges) { + setResultRanges(getResult(), argRanges.front()); +} + namespace { // Rewrites two back-to-back TransposeOp operations into a single TransposeOp. diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir index 2563b48cdd506..f89d307590e0b 100644 --- a/mlir/test/Dialect/Vector/int-range-interface.mlir +++ b/mlir/test/Dialect/Vector/int-range-interface.mlir @@ -51,6 +51,15 @@ func.func @vector_shape_cast() -> vector<4x4xindex> { func.return %2 : vector<4x4xindex> } +// CHECK-LABEL: func @vector_transpose +// CHECK: test.reflect_bounds {smax = 8 : index, smin = 7 : index, umax = 8 : index, umin = 7 : index} +func.func @vector_transpose() -> vector<2x4xindex> { + %0 = test.with_bounds { smax = 8 : index, smin = 7 : index, umax = 8 : index, umin = 7 : index } : vector<4x2xindex> + %1 = vector.transpose %0, [1, 0] : vector<4x2xindex> to vector<2x4xindex> + %2 = test.reflect_bounds %1 : vector<2x4xindex> + func.return %2 : vector<2x4xindex> +} + // CHECK-LABEL: func @vector_extract // CHECK: test.reflect_bounds {smax = 6 : index, smin = 5 : index, umax = 6 : index, umin = 5 : index} func.func @vector_extract() -> index {