Skip to content

Commit 8c4d201

Browse files
committed
add to_elements builder that infers the scalar type(s)
1 parent b6445ac commit 8c4d201

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -798,7 +798,7 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
798798
This operation decomposes all the scalar elements from a vector. The
799799
decomposed scalar elements are returned in row-major order. The number of
800800
scalar results must match the number of elements in the input vector type.
801-
All the result elements have the same result type, which must match the
801+
All the result elements have the same type, which must match the
802802
element type of the input vector. Scalable vectors are not supported.
803803

804804
Examples:
@@ -813,7 +813,7 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
813813
// %0#0 = %v1[0]
814814
// %0#1 = %v1[1]
815815

816-
// Decompose a 2-D.
816+
// Decompose a 2-D vector.
817817
%0:6 = vector.to_elements %v2 : vector<2x3xf32>
818818
// %0#0 = %v2[0, 0]
819819
// %0#1 = %v2[0, 1]
@@ -835,6 +835,13 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
835835

836836
let arguments = (ins AnyVectorOfAnyRank:$source);
837837
let results = (outs Variadic<AnyType>:$elements);
838+
839+
840+
let builders = [
841+
// Build method that infers the result types from `elements`.
842+
OpBuilder<(ins "Value":$elements)>,
843+
];
844+
838845
let assemblyFormat = "$source attr-dict `:` type($source)";
839846
let hasFolder = 1;
840847
}

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2417,6 +2417,14 @@ LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
24172417
return foldToElementsFromElements(*this, results);
24182418
}
24192419

2420+
void vector::ToElementsOp::build(OpBuilder &builder, OperationState &result, Value elements) {
2421+
auto vectorType = cast<VectorType>(elements.getType());
2422+
Type elementType = vectorType.getElementType();
2423+
int64_t nbElements = vectorType.getNumElements();
2424+
SmallVector<Type> scalarTypes(nbElements, elementType);
2425+
build(builder, result, scalarTypes, elements);
2426+
}
2427+
24202428
//===----------------------------------------------------------------------===//
24212429
// FromElementsOp
24222430
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)