Skip to content

Commit ee717ee

Browse files
committed
[mlir][linalg] Add comp-type to new elementwise-op.
1 parent 5cc2ae0 commit ee717ee

File tree

4 files changed

+93
-8
lines changed

4 files changed

+93
-8
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -563,13 +563,16 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
563563
The number of dims of the iterator-types is inferred from the rank of
564564
the result type.
565565

566+
Numeric casting is performed on the input operand, promoting it to the same
567+
data type as the result.
568+
566569
Example:
567570

568571
Defining a unary linalg.elemwise with default indexing-map:
569572
```mlir
570573
%exp = linalg.elemwise
571574
kind=#linalg.elemwise_kind<exp>
572-
ins(%x : tensor<4x16x8xf32>)
575+
ins(%x : tensor<4x16x8xf16>)
573576
outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
574577
```
575578

@@ -587,7 +590,8 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
587590
Variadic<AnyType>:$inputs,
588591
Variadic<AnyShaped>:$outputs,
589592
ElementwiseKindAttr:$kind,
590-
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
593+
DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
594+
DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
591595
);
592596

593597
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4250,17 +4250,36 @@ void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
42504250
SmallVector<Value> yields;
42514251
Value result;
42524252

4253+
TypeFn castVal = TypeFn::cast_signed;
4254+
auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4255+
return attr.getName() == "cast";
4256+
});
4257+
4258+
if (castIter != attrs.end()) {
4259+
if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4260+
castVal = attr.getValue();
4261+
}
4262+
42534263
if (arityGroup == ElementwiseArityGroup::Unary) {
4254-
result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
4264+
Value val0 = helper.buildTypeFn(castVal, block.getArgument(1).getType(),
4265+
block.getArgument(0));
4266+
result = helper.buildUnaryFn(kind.unaryFn, val0);
42554267

42564268
} else if (arityGroup == ElementwiseArityGroup::Binary) {
4257-
result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
4258-
block.getArgument(1));
4269+
Value val0 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
4270+
block.getArgument(0));
4271+
Value val1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
4272+
block.getArgument(1));
4273+
result = helper.buildBinaryFn(kind.binaryFn, val0, val1);
42594274

42604275
} else if (arityGroup == ElementwiseArityGroup::Ternary) {
4261-
result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
4262-
block.getArgument(1), block.getArgument(2));
4263-
4276+
// select op's select-arg (block arg 0) must remain bool.
4277+
Value val1 = helper.buildTypeFn(castVal, block.getArgument(3).getType(),
4278+
block.getArgument(1));
4279+
Value val2 = helper.buildTypeFn(castVal, block.getArgument(3).getType(),
4280+
block.getArgument(2));
4281+
result =
4282+
helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0), val1, val2);
42644283
} else
42654284
assert(false && "found unhandled category in elemwise");
42664285

mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,27 @@ func.func @ternary(%A : tensor<32x16xi1>, %B: tensor<8x16x32xf32>, %C : tensor<8
163163
outs(%D: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
164164
return %r : tensor<8x16x32xf32>
165165
}
166+
167+
// -----
168+
169+
// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
170+
//
171+
// CHECK: @cast_f16_to_f32(%[[A:.+]]: tensor<16x8xf16>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>)
172+
// CHECK: linalg.generic
173+
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
174+
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
175+
// CHECK-SAME: ins(%[[A]], %[[B]]
176+
// CHECK-SAME: outs(%[[C]]
177+
//
178+
// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
179+
// CHECK: %[[CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
180+
// CHECK: %[[MUL:.+]] = arith.mulf %[[CAST]], %[[B_ARG]] : f32
181+
// CHECK: linalg.yield %[[MUL]] : f32
182+
//
183+
func.func @cast_f16_to_f32(%A : tensor<16x8xf16>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
184+
%r = linalg.elementwise
185+
kind=#linalg.elementwise_kind<mul>
186+
ins(%A, %B: tensor<16x8xf16>, tensor<16x8xf32>)
187+
outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
188+
return %r : tensor<16x8xf32>
189+
}

mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,41 @@ func.func @redundant_maps(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
8888
outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
8989
return %r : tensor<1x2x3x4x5xi32>
9090
}
91+
92+
// -----
93+
94+
// CHECK: @convert_f16_to_f32(%[[A:.+]]: tensor<16x8xf16>, %[[B:.+]]: tensor<16x8xf32>,
95+
// CHECK-SAME: %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
96+
// CHECK: {{.*}} = linalg.elementwise
97+
// CHECK-SAME: kind=#linalg.elementwise_kind<div>
98+
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xf16>, tensor<16x8xf32>)
99+
// CHECK-SAME: outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
100+
//
101+
func.func @convert_f16_to_f32(%A: tensor<16x8xf16>, %B: tensor<16x8xf32>,
102+
%C: tensor<16x8xf32>) -> tensor<16x8xf32> {
103+
%r = linalg.elementwise
104+
kind=#linalg.elementwise_kind<div>
105+
ins(%A, %B: tensor<16x8xf16>, tensor<16x8xf32>)
106+
outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
107+
return %r : tensor<16x8xf32>
108+
}
109+
110+
111+
// -----
112+
113+
// CHECK: @explicit_cast(%[[A:.+]]: tensor<16x8xi16>, %[[B:.+]]: tensor<16x8xi32>,
114+
// CHECK-SAME: %[[C:.+]]: tensor<16x8xi32>) -> tensor<16x8xi32> {
115+
// CHECK: {{.*}} = linalg.elementwise
116+
// CHECK-SAME: kind=#linalg.elementwise_kind<add>
117+
// CHECK-SAME: {cast = #linalg.type_fn<cast_signed>}
118+
// CHECK-SAME: ins(%[[A]], %[[B]] : tensor<16x8xi16>, tensor<16x8xi32>)
119+
// CHECK-SAME: outs(%[[C]] : tensor<16x8xi32>) -> tensor<16x8xi32>
120+
//
121+
func.func @explicit_cast(%A: tensor<16x8xi16>, %B: tensor<16x8xi32>, %C: tensor<16x8xi32>) -> tensor<16x8xi32> {
122+
%0 = linalg.elementwise
123+
kind=#linalg.elementwise_kind<add>
124+
{cast = #linalg.type_fn<cast_signed>}
125+
ins(%A, %B : tensor<16x8xi16>, tensor<16x8xi32>)
126+
outs(%C : tensor<16x8xi32>) -> tensor<16x8xi32>
127+
return %0 : tensor<16x8xi32>
128+
}

0 commit comments

Comments
 (0)