Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 67 additions & 18 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -790,40 +790,89 @@ def Vector_FMAOp :
}];
}

def Vector_ToElementsOp : Vector_Op<"to_elements", [
Pure,
ShapedTypeMatchesElementCountAndTypes<"source", "elements">]> {
let summary = "operation that decomposes a vector into all its scalar elements";
let description = [{
This operation decomposes all the scalar elements from a vector. The
decomposed scalar elements are returned in row-major order. The number of
scalar results must match the number of elements in the input vector type.
All the result elements have the same result type, which must match the
element type of the input vector. Scalable vectors are not supported.
Comment on lines +798 to +802
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it important that it decomposes into all elements? This op could be really useful for unrolling a dimension if we could do it dimwise. Something like:

%0:16 = vector.to_elements %v : vector<16x4xf32> -> vector<4xf32>

This should have the exact same semantics as vector.extract, just doing multiple extracts at once.

I would much rather have this form of the operation, it is much closer to vector.extract and works for N-D vectors much better.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that keeping the symmetry with from_elements is valuable. I'm not sure I follow the suggestion, but is it doing something that chaining extract / extract_strided_slice / shape_cast / to_elements cannot achieve?


Examples:

```mlir
// Decompose a 0-D vector.
%0 = vector.to_elements %v0 : vector<f32>
// %0 = %v0[0]

// Decompose a 1-D vector.
%0:2 = vector.to_elements %v1 : vector<2xf32>
// %0#0 = %v1[0]
// %0#1 = %v1[1]

// Decompose a 2-D.
%0:6 = vector.to_elements %v2 : vector<2x3xf32>
// %0#0 = %v2[0, 0]
// %0#1 = %v2[0, 1]
// %0#2 = %v2[0, 2]
// %0#3 = %v2[1, 0]
// %0#4 = %v2[1, 1]
// %0#5 = %v2[1, 2]

// Decompose a 3-D vector.
%0:6 = vector.to_elements %v3 : vector<3x1x2xf32>
// %0#0 = %v3[0, 0, 0]
// %0#1 = %v3[0, 0, 1]
// %0#2 = %v3[1, 0, 0]
// %0#3 = %v3[1, 0, 1]
// %0#4 = %v3[2, 0, 0]
// %0#5 = %v3[2, 0, 1]
```
}];

let arguments = (ins AnyVectorOfAnyRank:$source);
let results = (outs Variadic<AnyType>:$elements);
let assemblyFormat = "$source attr-dict `:` type($source)";
}

def Vector_FromElementsOp : Vector_Op<"from_elements", [
Pure,
TypesMatchWith<"operand types match result element type",
"result", "elements", "SmallVector<Type>("
"::llvm::cast<VectorType>($_self).getNumElements(), "
"::llvm::cast<VectorType>($_self).getElementType())">]> {
ShapedTypeMatchesElementCountAndTypes<"dest", "elements">]> {
let summary = "operation that defines a vector from scalar elements";
let description = [{
This operation defines a vector from one or multiple scalar elements. The
number of elements must match the number of elements in the result type.
All elements must have the same type, which must match the element type of
the result vector type.

`elements` are a flattened version of the result vector in row-major order.
scalar elements are arranged in row-major within the vector. The number of
elements must match the number of elements in the result type. All elements
must have the same type, which must match the element type of the result
vector type. Scalable vectors are not supported.

Example:
Examples:

```mlir
// %f1
// Define a 0-D vector.
%0 = vector.from_elements %f1 : vector<f32>
// [%f1, %f2]
// [%f1]

// Define a 1-D vector.
%1 = vector.from_elements %f1, %f2 : vector<2xf32>
// [[%f1, %f2, %f3], [%f4, %f5, %f6]]
// [%f1, %f2]

// Define a 2-D vector.
%2 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<2x3xf32>
// [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]]
// [[%f1, %f2, %f3], [%f4, %f5, %f6]]

// Define a 3-D vector.
%3 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<3x1x2xf32>
// [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]]
```

Note, scalable vectors are not supported.
}];

let arguments = (ins Variadic<AnyType>:$elements);
let results = (outs AnyFixedVectorOfAnyRank:$result);
let assemblyFormat = "$elements attr-dict `:` type($result)";
let results = (outs AnyFixedVectorOfAnyRank:$dest);
let assemblyFormat = "$elements attr-dict `:` type($dest)";
let hasCanonicalizer = 1;
}

Expand Down
19 changes: 19 additions & 0 deletions mlir/include/mlir/IR/OpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,25 @@ class AllShapesMatch<list<string> names> :
class AllTypesMatch<list<string> names> :
AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;

// A type constraint that verifies that a shaped type matches the size and
// element type of a container with element types. More specifically, it denotes
// shapedArg.getType().getNumElements() == elementsArg.size() &&
// shapedArg.getType().getElementType() == elementsArg[i].getType(), for i in
// [0, elementsArg.size()).
class ShapedTypeMatchesElementCountAndTypes<string shapedArg,
string elementsArg> :
PredOpTrait<"shaped type '" # shapedArg # "' matches '" # elementsArg # "' "
"element count and types",
And<[CPred<ElementCount<shapedArg>.result # " == "
"$" # elementsArg # ".getTypes().size()">,
CPred<"::llvm::all_of($" # elementsArg # ".getTypes(), "
"[&](::mlir::Type t) { return t == "
# ElementType<shapedArg>.result # "; })">]>> {

string shaped = shapedArg;
string elements = elementsArg;
}

// A type constraint that denotes `transform(lhs.getType()) == rhs.getType()`.
// An optional comparator function may be provided that changes the above form
// into: `comparator(transform(lhs.getType()), rhs.getType())`.
Expand Down
31 changes: 31 additions & 0 deletions mlir/lib/TableGen/Operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,37 @@ void Operator::populateTypeInferenceInfo(
continue;
}

// The `ShapedTypeMatchesElementCountAndTypes` trait represents a 1 -> 1
// type inference edge where a shaped type matches element count and types
// of variadic elements.
if (def.isSubClassOf("ShapedTypeMatchesElementCountAndTypes")) {
StringRef shapedArg = def.getValueAsString("shaped");
StringRef elementsArg = def.getValueAsString("elements");

int shapedIndex = argumentsAndResultsIndex.lookup(shapedArg);
int elementsIndex = argumentsAndResultsIndex.lookup(elementsArg);

// Handle result type inference from shaped type to variadic elements.
if (InferredResultType::isResultIndex(elementsIndex) &&
InferredResultType::isArgIndex(shapedIndex)) {
int resultIndex = InferredResultType::unmapResultIndex(elementsIndex);
ResultTypeInference &infer = inference[resultIndex];
if (!infer.inferred) {
infer.sources.emplace_back(
shapedIndex,
"::llvm::SmallVector<::mlir::Type>(::llvm::cast<::mlir::"
"ShapedType>($_self).getNumElements(), "
"::llvm::cast<::mlir::ShapedType>($_self).getElementType())");
infer.inferred = true;
}
}

// Type inference in the opposite direction is not possible as the actual
// shaped type can't be inferred from the variadic elements.

continue;
}

if (!def.isSubClassOf("AllTypesMatch"))
continue;

Expand Down
24 changes: 20 additions & 4 deletions mlir/test/Dialect/Vector/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1896,7 +1896,24 @@ func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) {

// -----

func.func @invalid_from_elements(%a: f32) {
func.func @to_elements_wrong_num_results(%a: vector<1x1x2xf32>) {
// expected-error @+1 {{operation defines 2 results but was provided 4 to bind}}
%0:4 = vector.to_elements %a : vector<1x1x2xf32>
return
}

// -----

func.func @to_elements_wrong_result_type(%a: vector<2xf32>) -> i32 {
// expected-error @+3 {{use of value '%0' expects different type than prior uses: 'i32'}}
// expected-note @+1 {{prior use here}}
%0:2 = vector.to_elements %a : vector<2xf32>
return %0#0 : i32
}

// -----

func.func @from_elements_wrong_num_operands(%a: f32) {
// expected-error @+1 {{'vector.from_elements' number of operands and types do not match: got 1 operands and 2 types}}
vector.from_elements %a : vector<2xf32>
return
Expand All @@ -1905,16 +1922,15 @@ func.func @invalid_from_elements(%a: f32) {
// -----

// expected-note @+1 {{prior use here}}
func.func @invalid_from_elements(%a: f32, %b: i32) {
func.func @from_elements_wrong_operand_type(%a: f32, %b: i32) {
// expected-error @+1 {{use of value '%b' expects different type than prior uses: 'f32' vs 'i32'}}
vector.from_elements %a, %b : vector<2xf32>
return
}

// -----

func.func @invalid_from_elements_scalable(%a: f32, %b: i32) {
// expected-error @+1 {{'result' must be fixed-length vector of any type values, but got 'vector<[2]xf32>'}}
// expected-error @+1 {{'dest' must be fixed-length vector of any type values, but got 'vector<[2]xf32>'}}
vector.from_elements %a, %b : vector<[2]xf32>
return
}
Expand Down
19 changes: 19 additions & 0 deletions mlir/test/Dialect/Vector/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1175,6 +1175,25 @@ func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4
return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32>
}

// CHECK-LABEL: func @to_elements(
// CHECK-SAME: %[[A_VEC:.*]]: vector<f32>, %[[B_VEC:.*]]: vector<4xf32>,
// CHECK-SAME: %[[C_VEC:.*]]: vector<1xf32>, %[[D_VEC:.*]]: vector<2x2xf32>)
func.func @to_elements(%a_vec : vector<f32>, %b_vec : vector<4xf32>, %c_vec : vector<1xf32>, %d_vec : vector<2x2xf32>)
-> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
// CHECK: %[[A_ELEMS:.*]] = vector.to_elements %[[A_VEC]] : vector<f32>
%0 = vector.to_elements %a_vec : vector<f32>
// CHECK: %[[B_ELEMS:.*]]:4 = vector.to_elements %[[B_VEC]] : vector<4xf32>
%1:4 = vector.to_elements %b_vec : vector<4xf32>
// CHECK: %[[C_ELEMS:.*]] = vector.to_elements %[[C_VEC]] : vector<1xf32>
%2 = vector.to_elements %c_vec : vector<1xf32>
// CHECK: %[[D_ELEMS:.*]]:4 = vector.to_elements %[[D_VEC]] : vector<2x2xf32>
%3:4 = vector.to_elements %d_vec : vector<2x2xf32>
// CHECK: return %[[A_ELEMS]], %[[B_ELEMS]]#0, %[[B_ELEMS]]#1, %[[B_ELEMS]]#2,
// CHECK-SAME: %[[B_ELEMS]]#3, %[[C_ELEMS]], %[[D_ELEMS]]#0, %[[D_ELEMS]]#1,
// CHECK-SAME: %[[D_ELEMS]]#2, %[[D_ELEMS]]#3
return %0, %1#0, %1#1, %1#2, %1#3, %2, %3#0, %3#1, %3#2, %3#3 : f32, f32, f32, f32, f32, f32, f32, f32, f32, f32
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We seem to be mixing styles in this file:

  • One Op per function vs Multiple Ops per function.

Definitely not asking you to change that, just pointing out unfortunate inconsistency.

Now, I am thinking though that it would be good to make this symmetrical to @from_elements and use identical shapes. Specifically, this example takes two 1D vectors. Do we need to? Wouldn't a combination of 0D, 1D, 2D make more sense?

TBH, one example with 0D and 2D would IMHO be sufficient.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made the test symmetrical but I'm not sure we should too mechanical about this. A bit of randomness is helpful to increase coverage and expose bugs because... we never know where they could be... I also like to test the 0-D and 1-element 1-D because there are some nuances with those types, as you well know :)


// CHECK-LABEL: func @from_elements(
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32)
func.func @from_elements(%a: f32, %b: f32) -> (vector<f32>, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>) {
Expand Down
26 changes: 26 additions & 0 deletions mlir/tools/mlir-tblgen/OpFormatGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2787,6 +2787,11 @@ class OpFormatParser : public FormatParser {
void handleTypesMatchConstraint(
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def);

/// Check for inferable type resolution based on
/// `ShapedTypeMatchesElementCountAndTypes` constraint.
void handleShapedTypeMatchesElementCountAndTypesConstraint(
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def);

/// Returns an argument or attribute with the given name that has been seen
/// within the format.
ConstArgument findSeenArg(StringRef name);
Expand Down Expand Up @@ -2850,6 +2855,9 @@ LogicalResult OpFormatParser::verify(SMLoc loc,
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
} else if (def.isSubClassOf("TypesMatchWith")) {
handleTypesMatchConstraint(variableTyResolver, def);
} else if (def.isSubClassOf("ShapedTypeMatchesElementCountAndTypes")) {
handleShapedTypeMatchesElementCountAndTypesConstraint(variableTyResolver,
def);
} else if (!op.allResultTypesKnown()) {
// This doesn't check the name directly to handle
// DeclareOpInterfaceMethods<InferTypeOpInterface>
Expand Down Expand Up @@ -3289,6 +3297,24 @@ void OpFormatParser::handleTypesMatchConstraint(
variableTyResolver[rhsName] = {arg, transformer};
}

void OpFormatParser::handleShapedTypeMatchesElementCountAndTypesConstraint(
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def) {
StringRef shapedArg = def.getValueAsString("shaped");
StringRef elementsArg = def.getValueAsString("elements");

// Check if the 'shaped' argument is seen, then we can infer the 'elements'
// types.
if (ConstArgument arg = findSeenArg(shapedArg)) {
variableTyResolver[elementsArg] = {
arg, "::llvm::SmallVector<::mlir::Type>(::llvm::cast<::mlir::"
"ShapedType>($_self).getNumElements(), "
"::llvm::cast<::mlir::ShapedType>($_self).getElementType())"};
}

// Type inference in the opposite direction is not possible as the actual
// shaped type can't be inferred from the variadic elements.
}

ConstArgument OpFormatParser::findSeenArg(StringRef name) {
if (const NamedTypeConstraint *arg = findArg(op.getOperands(), name))
return seenOperandTypes.test(arg - op.operand_begin()) ? arg : nullptr;
Expand Down
Loading