diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 308e39a9a51e1..af85daca1c078 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -563,13 +563,16 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
     The number of dims of the iterator-types are inferred from the rank of
     the result type.
 
+    Numeric casting is performed on the input operands, promoting them to the
+    same data type as the result.
+
     Example:
 
     Defining a unary linalg.elementwise with default indexing-map:
 
     ```mlir
     %exp = linalg.elementwise
            kind=#linalg.elementwise_kind<exp>
-           ins(%x : tensor<4x16x8xf32>)
+           ins(%x : tensor<4x16x8xf16>)
            outs(%y: tensor<4x16x8xf32>) -> tensor<4x16x8xf32>
     ```
@@ -587,7 +590,8 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
     Variadic<AnyType>:$inputs,
     Variadic<AnyShaped>:$outputs,
     ElementwiseKindAttr:$kind,
-    DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps
+    DefaultValuedOptionalAttr<AffineMapArrayAttr, "{}">:$indexing_maps,
+    DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast
   );
 
   let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 07b19e5cb1a89..0ffa259023faf 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4250,17 +4250,36 @@ void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
   SmallVector<Value> yields;
   Value result;
 
+  TypeFn castVal = TypeFn::cast_signed;
+  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
+    return attr.getName() == "cast";
+  });
+
+  if (castIter != attrs.end()) {
+    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
+      castVal = attr.getValue();
+  }
+
   if (arityGroup == ElementwiseArityGroup::Unary) {
-    result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
+    Value val0 = helper.buildTypeFn(castVal, block.getArgument(1).getType(),
+                                    block.getArgument(0));
+    result = helper.buildUnaryFn(kind.unaryFn, val0);
 
   } else if (arityGroup == ElementwiseArityGroup::Binary) {
-    result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
-                                  block.getArgument(1));
+    Value val0 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                    block.getArgument(0));
+    Value val1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
+                                    block.getArgument(1));
+    result = helper.buildBinaryFn(kind.binaryFn, val0, val1);
 
   } else if (arityGroup == ElementwiseArityGroup::Ternary) {
-    result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
-                                   block.getArgument(1), block.getArgument(2));
-
+    // The select op's condition (block arg 0) must remain bool.
+    Value val1 = helper.buildTypeFn(castVal, block.getArgument(3).getType(),
+                                    block.getArgument(1));
+    Value val2 = helper.buildTypeFn(castVal, block.getArgument(3).getType(),
+                                    block.getArgument(2));
+    result =
+        helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0), val1, val2);
   } else
     assert(false && "found unhandled category in elemwise");
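To make the ternary special case concrete: `buildTypeFn` with the default `TypeFn::cast_signed` lowers an `f16` to `f32` promotion to `arith.extf`, while the condition operand is passed through untouched. A minimal sketch (illustrative only, not part of the patch) of the region this regionBuilder would be expected to produce for `kind=#linalg.elementwise_kind<select>` with `f16` value operands and an `f32` result:

```mlir
// Hypothetical generalized body for a select with f16 values, f32 output.
^bb0(%cond: i1, %t: f16, %f: f16, %out: f32):
  %0 = arith.extf %t : f16 to f32        // cast true-value to result type
  %1 = arith.extf %f : f16 to f32        // cast false-value to result type
  %2 = arith.select %cond, %0, %1 : f32  // condition stays i1, uncast
  linalg.yield %2 : f32
```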
diff --git a/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
index e884858c016f4..19fb0e61d450b 100644
--- a/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/generalize-named-ops.mlir
@@ -163,3 +163,27 @@ func.func @ternary(%A : tensor<32x16xi1>, %B: tensor<8x16x32xf32>, %C : tensor<8
                        outs(%D: tensor<8x16x32xf32>) -> tensor<8x16x32xf32>
   return %r : tensor<8x16x32xf32>
 }
+
+// -----
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+//
+// CHECK: @cast_f16_to_f32(%[[A:.+]]: tensor<16x8xf16>, %[[B:.+]]: tensor<16x8xf32>, %[[C:.+]]: tensor<16x8xf32>)
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[A]], %[[B]]
+// CHECK-SAME: outs(%[[C]]
+//
+// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f16, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK:   %[[CAST:.+]] = arith.extf %[[A_ARG]] : f16 to f32
+// CHECK:   %[[MUL:.+]] = arith.mulf %[[CAST]], %[[B_ARG]] : f32
+// CHECK:   linalg.yield %[[MUL]] : f32
+//
+func.func @cast_f16_to_f32(%A : tensor<16x8xf16>, %B: tensor<16x8xf32>, %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
+  %r = linalg.elementwise
+      kind=#linalg.elementwise_kind<mul>
+      ins(%A, %B: tensor<16x8xf16>, tensor<16x8xf32>)
+      outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %r : tensor<16x8xf32>
+}
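The unary path works the same way: with an `f16` input and `f32` output, the single input is cast to the output element type before the unary function is applied. A hypothetical generalized body for the `kind=#linalg.elementwise_kind<exp>` example from the op documentation (a sketch, not a test from this patch):

```mlir
^bb0(%x: f16, %y: f32):
  %0 = arith.extf %x : f16 to f32  // input promoted to the result type
  %1 = math.exp %0 : f32           // unary fn applied after the cast
  linalg.yield %1 : f32
```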
diff --git a/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir b/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
index 20ebdd992b5a1..0bce89ca378a4 100644
--- a/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/elementwise/roundtrip.mlir
@@ -88,3 +88,41 @@ func.func @redundant_maps(%A: tensor<1x2x3x4x5xi32>, %B: tensor<1x2x3x4x5xi32>,
                      outs(%C: tensor<1x2x3x4x5xi32>) -> tensor<1x2x3x4x5xi32>
   return %r : tensor<1x2x3x4x5xi32>
 }
+
+// -----
+
+// CHECK: @convert_f16_to_f32(%[[A:.+]]: tensor<16x8xf16>, %[[B:.+]]: tensor<16x8xf32>,
+// CHECK-SAME:                %[[C:.+]]: tensor<16x8xf32>) -> tensor<16x8xf32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME:        kind=#linalg.elementwise_kind<mul>
+// CHECK-SAME:        ins(%[[A]], %[[B]] : tensor<16x8xf16>, tensor<16x8xf32>)
+// CHECK-SAME:        outs(%[[C]] : tensor<16x8xf32>) -> tensor<16x8xf32>
+//
+func.func @convert_f16_to_f32(%A: tensor<16x8xf16>, %B: tensor<16x8xf32>,
+                              %C: tensor<16x8xf32>) -> tensor<16x8xf32> {
+  %r = linalg.elementwise
+      kind=#linalg.elementwise_kind<mul>
+      ins(%A, %B: tensor<16x8xf16>, tensor<16x8xf32>)
+      outs(%C: tensor<16x8xf32>) -> tensor<16x8xf32>
+  return %r : tensor<16x8xf32>
+}
+
+
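Since `$cast` is declared as a `DefaultValuedOptionalAttr`, spelling out the default `cast_signed` should round-trip to the elided form checked above. A hypothetical input (reusing the `@convert_f16_to_f32` operands) illustrating that assumption:

```mlir
%r = linalg.elementwise
    kind=#linalg.elementwise_kind<mul>
    {cast = #linalg.type_fn<cast_signed>}  // default value; expected to be
                                           // elided by the printer
    ins(%A, %B : tensor<16x8xf16>, tensor<16x8xf32>)
    outs(%C : tensor<16x8xf32>) -> tensor<16x8xf32>
```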
+// -----
+
+// CHECK: @explicit_cast(%[[A:.+]]: tensor<16x8xi16>, %[[B:.+]]: tensor<16x8xi32>,
+// CHECK-SAME:           %[[C:.+]]: tensor<16x8xi32>) -> tensor<16x8xi32> {
+// CHECK: {{.*}} = linalg.elementwise
+// CHECK-SAME:        kind=#linalg.elementwise_kind<add>
+// CHECK-SAME:        {cast = #linalg.type_fn<cast_unsigned>}
+// CHECK-SAME:        ins(%[[A]], %[[B]] : tensor<16x8xi16>, tensor<16x8xi32>)
+// CHECK-SAME:        outs(%[[C]] : tensor<16x8xi32>) -> tensor<16x8xi32>
+//
+func.func @explicit_cast(%A: tensor<16x8xi16>, %B: tensor<16x8xi32>, %C: tensor<16x8xi32>) -> tensor<16x8xi32> {
+  %0 = linalg.elementwise
+      kind=#linalg.elementwise_kind<add>
+      {cast = #linalg.type_fn<cast_unsigned>}
+      ins(%A, %B : tensor<16x8xi16>, tensor<16x8xi32>)
+      outs(%C : tensor<16x8xi32>) -> tensor<16x8xi32>
+  return %0 : tensor<16x8xi32>
+}
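For contrast with the `f16` tests: generalizing `@explicit_cast` would be expected to widen `%A` with `arith.extui`, whereas the default `cast_signed` would produce `arith.extsi` for the same `i16` to `i32` promotion. A sketch of the resulting body (illustrative, not part of the patch):

```mlir
^bb0(%a: i16, %b: i32, %c: i32):
  %0 = arith.extui %a : i16 to i32  // cast_unsigned; cast_signed would
                                    // emit arith.extsi here instead
  %1 = arith.addi %0, %b : i32
  linalg.yield %1 : i32
```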