Skip to content

Commit dd8a9e2

Browse files
authored
[mlir][vector] Remove vector.reshape operation (llvm#101645)
This operation was added five years ago and has no lowerings or uses within upstream MLIR (and no reported uses downstream). There’s only a handful of round-trip tests. See related RFC: https://discourse.llvm.org/t/rfc-should-vector-reshape-be-removed/80478/3
1 parent 72fb188 commit dd8a9e2

File tree

4 files changed

+0
-255
lines changed

4 files changed

+0
-255
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 0 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -1178,122 +1178,6 @@ def Vector_OuterProductOp :
11781178
let hasVerifier = 1;
11791179
}
11801180

1181-
// TODO: Add transformation which decomposes ReshapeOp into an optimized
1182-
// sequence of vector rotate/shuffle/select operations.
1183-
def Vector_ReshapeOp :
1184-
Vector_Op<"reshape", [AttrSizedOperandSegments, Pure]>,
1185-
Arguments<(ins AnyVector:$vector, Variadic<Index>:$input_shape,
1186-
Variadic<Index>:$output_shape,
1187-
I64ArrayAttr:$fixed_vector_sizes)>,
1188-
Results<(outs AnyVector:$result)> {
1189-
let summary = "vector reshape operation";
1190-
let description = [{
1191-
Reshapes its vector operand from 'input_shape' to 'output_shape' maintaining
1192-
fixed vector dimension 'fixed_vector_sizes' on the innermost vector
1193-
dimensions.
1194-
1195-
The parameters 'input_shape' and 'output_shape' represent valid data shapes
1196-
across fixed vector shapes. For example, if a vector has a valid data
1197-
shape [6] with fixed vector size [8], then the valid data elements are
1198-
assumed to be stored at the beginning of the vector with the remaining
1199-
vector elements undefined.
1200-
1201-
In the examples below, valid data elements are represented by an alphabetic
1202-
character, and undefined data elements are represented by '-'.
1203-
1204-
Example
1205-
1206-
vector<1x8xf32> with valid data shape [6], fixed vector sizes [8]
1207-
1208-
input: [a, b, c, d, e, f]
1209-
1210-
layout map: (d0) -> (d0 floordiv 8, d0 mod 8)
1211-
1212-
vector layout: [a, b, c, d, e, f, -, -]
1213-
1214-
Example
1215-
1216-
vector<2x8xf32> with valid data shape [10], fixed vector sizes [8]
1217-
1218-
input: [a, b, c, d, e, f, g, h, i, j]
1219-
1220-
layout map: (d0) -> (d0 floordiv 8, d0 mod 8)
1221-
1222-
vector layout: [[a, b, c, d, e, f, g, h],
1223-
[i, j, -, -, -, -, -, -]]
1224-
1225-
Example
1226-
1227-
vector<2x2x2x3xf32> with valid data shape [3, 5], fixed vector sizes
1228-
[2, 3]
1229-
1230-
input: [[a, b, c, d, e],
1231-
[f, g, h, i, j],
1232-
[k, l, m, n, o]]
1233-
1234-
layout map: (d0, d1) -> (d0 floordiv 3, d1 floordiv 5,
1235-
d0 mod 3, d1 mod 5)
1236-
1237-
vector layout: [[[[a, b, c],
1238-
[f, g, h]]
1239-
[[d, e, -],
1240-
[i, j, -]]],
1241-
[[[k, l, m],
1242-
[-, -, -]]
1243-
[[n, o, -],
1244-
[-, -, -]]]]
1245-
1246-
Example
1247-
1248-
%1 = vector.reshape %0, [%c3, %c6], [%c2, %c9], [4]
1249-
: vector<3x2x4xf32> to vector<2x3x4xf32>
1250-
1251-
input: [[a, b, c, d, e, f],
1252-
[g, h, i, j, k, l],
1253-
[m, n, o, p, q, r]]
1254-
1255-
layout map: (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
1256-
1257-
1258-
Input vector: [[[a, b, c, d],
1259-
[e, f, -, -]],
1260-
[[g, h, i, j],
1261-
[k, l, -, -]],
1262-
[[m, n, o, p],
1263-
[q, r, -, -]]]
1264-
1265-
Output vector: [[[a, b, c, d],
1266-
[e, f, g, h],
1267-
[i, -, -, -]],
1268-
[[j, k, l, m],
1269-
[n, o, p, q],
1270-
[r, -, -, -]]]
1271-
}];
1272-
1273-
let extraClassDeclaration = [{
1274-
VectorType getInputVectorType() {
1275-
return ::llvm::cast<VectorType>(getVector().getType());
1276-
}
1277-
VectorType getOutputVectorType() {
1278-
return ::llvm::cast<VectorType>(getResult().getType());
1279-
}
1280-
1281-
/// Returns as integer value the number of input shape operands.
1282-
int64_t getNumInputShapeSizes() { return getInputShape().size(); }
1283-
1284-
/// Returns as integer value the number of output shape operands.
1285-
int64_t getNumOutputShapeSizes() { return getOutputShape().size(); }
1286-
1287-
void getFixedVectorSizes(SmallVectorImpl<int64_t> &results);
1288-
}];
1289-
1290-
let assemblyFormat = [{
1291-
$vector `,` `[` $input_shape `]` `,` `[` $output_shape `]` `,`
1292-
$fixed_vector_sizes attr-dict `:` type($vector) `to` type($result)
1293-
}];
1294-
let hasVerifier = 1;
1295-
}
1296-
12971181
def Vector_ExtractStridedSliceOp :
12981182
Vector_Op<"extract_strided_slice", [Pure,
12991183
PredOpTrait<"operand and result have same element type",

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 0 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -3327,68 +3327,6 @@ Type OuterProductOp::getExpectedMaskType() {
33273327
vecType.getScalableDims());
33283328
}
33293329

3330-
//===----------------------------------------------------------------------===//
3331-
// ReshapeOp
3332-
//===----------------------------------------------------------------------===//
3333-
3334-
LogicalResult ReshapeOp::verify() {
3335-
// Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
3336-
auto inputVectorType = getInputVectorType();
3337-
auto outputVectorType = getOutputVectorType();
3338-
int64_t inputShapeRank = getNumInputShapeSizes();
3339-
int64_t outputShapeRank = getNumOutputShapeSizes();
3340-
SmallVector<int64_t, 4> fixedVectorSizes;
3341-
getFixedVectorSizes(fixedVectorSizes);
3342-
int64_t numFixedVectorSizes = fixedVectorSizes.size();
3343-
3344-
if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
3345-
return emitError("invalid input shape for vector type ") << inputVectorType;
3346-
3347-
if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
3348-
return emitError("invalid output shape for vector type ")
3349-
<< outputVectorType;
3350-
3351-
// Verify that the 'fixedVectorSizes' match an input/output vector shape
3352-
// suffix.
3353-
unsigned inputVectorRank = inputVectorType.getRank();
3354-
for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
3355-
unsigned index = inputVectorRank - numFixedVectorSizes - i;
3356-
if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
3357-
return emitError("fixed vector size must match input vector for dim ")
3358-
<< i;
3359-
}
3360-
3361-
unsigned outputVectorRank = outputVectorType.getRank();
3362-
for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
3363-
unsigned index = outputVectorRank - numFixedVectorSizes - i;
3364-
if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
3365-
return emitError("fixed vector size must match output vector for dim ")
3366-
<< i;
3367-
}
3368-
3369-
// If all shape operands are produced by constant ops, verify that product
3370-
// of dimensions for input/output shape match.
3371-
auto isDefByConstant = [](Value operand) {
3372-
return getConstantIntValue(operand).has_value();
3373-
};
3374-
if (llvm::all_of(getInputShape(), isDefByConstant) &&
3375-
llvm::all_of(getOutputShape(), isDefByConstant)) {
3376-
int64_t numInputElements = 1;
3377-
for (auto operand : getInputShape())
3378-
numInputElements *= getConstantIntValue(operand).value();
3379-
int64_t numOutputElements = 1;
3380-
for (auto operand : getOutputShape())
3381-
numOutputElements *= getConstantIntValue(operand).value();
3382-
if (numInputElements != numOutputElements)
3383-
return emitError("product of input and output shape sizes must match");
3384-
}
3385-
return success();
3386-
}
3387-
3388-
void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
3389-
populateFromInt64AttrArray(getFixedVectorSizes(), results);
3390-
}
3391-
33923330
//===----------------------------------------------------------------------===//
33933331
// ExtractStridedSliceOp
33943332
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,66 +1094,6 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
10941094

10951095
// -----
10961096

1097-
func.func @reshape_bad_input_shape(%arg0 : vector<3x2x4xf32>) {
1098-
%c2 = arith.constant 2 : index
1099-
%c3 = arith.constant 3 : index
1100-
%c6 = arith.constant 6 : index
1101-
%c9 = arith.constant 9 : index
1102-
// expected-error@+1 {{invalid input shape for vector type}}
1103-
%1 = vector.reshape %arg0, [%c3, %c6, %c3], [%c2, %c9], [4]
1104-
: vector<3x2x4xf32> to vector<2x3x4xf32>
1105-
}
1106-
1107-
// -----
1108-
1109-
func.func @reshape_bad_output_shape(%arg0 : vector<3x2x4xf32>) {
1110-
%c2 = arith.constant 2 : index
1111-
%c3 = arith.constant 3 : index
1112-
%c6 = arith.constant 6 : index
1113-
%c9 = arith.constant 9 : index
1114-
// expected-error@+1 {{invalid output shape for vector type}}
1115-
%1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9, %c3], [4]
1116-
: vector<3x2x4xf32> to vector<2x3x4xf32>
1117-
}
1118-
1119-
// -----
1120-
1121-
func.func @reshape_bad_input_output_shape_product(%arg0 : vector<3x2x4xf32>) {
1122-
%c2 = arith.constant 2 : index
1123-
%c3 = arith.constant 3 : index
1124-
%c6 = arith.constant 6 : index
1125-
%c9 = arith.constant 9 : index
1126-
// expected-error@+1 {{product of input and output shape sizes must match}}
1127-
%1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c6], [4]
1128-
: vector<3x2x4xf32> to vector<2x3x4xf32>
1129-
}
1130-
1131-
// -----
1132-
1133-
func.func @reshape_bad_input_fixed_size(%arg0 : vector<3x2x5xf32>) {
1134-
%c2 = arith.constant 2 : index
1135-
%c3 = arith.constant 3 : index
1136-
%c6 = arith.constant 6 : index
1137-
%c9 = arith.constant 9 : index
1138-
// expected-error@+1 {{fixed vector size must match input vector for dim 0}}
1139-
%1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4]
1140-
: vector<3x2x5xf32> to vector<2x3x4xf32>
1141-
}
1142-
1143-
// -----
1144-
1145-
func.func @reshape_bad_output_fixed_size(%arg0 : vector<3x2x4xf32>) {
1146-
%c2 = arith.constant 2 : index
1147-
%c3 = arith.constant 3 : index
1148-
%c6 = arith.constant 6 : index
1149-
%c9 = arith.constant 9 : index
1150-
// expected-error@+1 {{fixed vector size must match output vector for dim 0}}
1151-
%1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4]
1152-
: vector<3x2x4xf32> to vector<2x3x5xf32>
1153-
}
1154-
1155-
// -----
1156-
11571097
func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
11581098
// expected-error@+1 {{op source/result vectors must have same element type}}
11591099
%0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -522,23 +522,6 @@ func.func @vector_print_on_scalar(%arg0: i64) {
522522
return
523523
}
524524

525-
// CHECK-LABEL: @reshape
526-
func.func @reshape(%arg0 : vector<3x2x4xf32>) -> (vector<2x3x4xf32>) {
527-
// CHECK: %[[C2:.*]] = arith.constant 2 : index
528-
%c2 = arith.constant 2 : index
529-
// CHECK: %[[C3:.*]] = arith.constant 3 : index
530-
%c3 = arith.constant 3 : index
531-
// CHECK: %[[C6:.*]] = arith.constant 6 : index
532-
%c6 = arith.constant 6 : index
533-
// CHECK: %[[C9:.*]] = arith.constant 9 : index
534-
%c9 = arith.constant 9 : index
535-
// CHECK: vector.reshape %{{.*}}, [%[[C3]], %[[C6]]], [%[[C2]], %[[C9]]], [4] : vector<3x2x4xf32> to vector<2x3x4xf32>
536-
%1 = vector.reshape %arg0, [%c3, %c6], [%c2, %c9], [4]
537-
: vector<3x2x4xf32> to vector<2x3x4xf32>
538-
539-
return %1 : vector<2x3x4xf32>
540-
}
541-
542525
// CHECK-LABEL: @shape_cast
543526
func.func @shape_cast(%arg0 : vector<5x1x3x2xf32>,
544527
%arg1 : vector<8x1xf32>,

0 commit comments

Comments
 (0)