Skip to content

Commit 006c392

Browse files
committed
[MLIR][Linalg] Add maximumf as an binary linalg.elementwise fn
Is a proper generalisation of `linalg.max` as `linalg.elementwise allows for folding broadcasts and transposes into the op.
1 parent a8f1f1b commit 006c392

File tree

3 files changed

+21
-1
lines changed

3 files changed

+21
-1
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ def BinaryFn : I32EnumAttr<"BinaryFn", "", [
4444
I32EnumAttrCase<"min_signed", 6>,
4545
I32EnumAttrCase<"max_unsigned", 7>,
4646
I32EnumAttrCase<"min_unsigned", 8>,
47-
I32EnumAttrCase<"powf", 9>
47+
I32EnumAttrCase<"powf", 9>,
48+
I32EnumAttrCase<"maximumf", 10>
4849
]> {
4950
let genSpecializedAttr = 0;
5051
let cppNamespace = "::mlir::linalg";

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,9 @@ class RegionBuilderHelper {
590590
case BinaryFn::powf:
591591
assert(allFloatingPoint);
592592
return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
593+
case BinaryFn::maximumf:
594+
assert(allFloatingPoint);
595+
return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
593596
}
594597
if (emitError) {
595598
emitError() << "unsupported binary function";

mlir/test/python/dialects/linalg/ops.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,22 @@ def elementwise_op(
831831
],
832832
)
833833

834+
# CHECK: linalg.elementwise kind=#linalg.elementwise_kind<maximumf>
835+
# CHECK-SAME: indexing_maps = [#[[$VertLineBCastMap]], #[[$HorLineBCastMap]], #[[$IdentMap2D]]]
836+
# CHECK-SAME: ins(%[[VertLine]], %[[HorLine]] : tensor<8xf32>, tensor<16xf32>)
837+
# CHECK-SAME: outs(%[[OutRect]] : tensor<8x16xf32>) -> tensor<8x16xf32>
838+
linalg.elementwise(
839+
vert_line,
840+
hor_line,
841+
outs=(out_rect,),
842+
kind=linalg.ElementwiseKind.maximumf,
843+
indexing_maps=[
844+
vert_line_bcast_map,
845+
hor_line_bcast_map,
846+
ident_map_2d,
847+
],
848+
)
849+
834850
if _ops_with_non_ident_and_transposed_input_maps := True:
835851
# CHECK: %[[VertLineBoolsMem:.*]] = memref.alloca() : memref<8xi1>
836852
vert_line_bools_mem = memref.alloca(

0 commit comments

Comments
 (0)