diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp index 7626d356a37f2..c64e10f534f8e 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp @@ -123,7 +123,8 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality( vector::OuterProductOp, vector::ScanOp>( [&](Operation *op) { return converter.isLegal(op); }); target.addLegalOp(); + arith::ConstantOp, arith::SelectOp, vector::SplatOp, + vector::BroadcastOp>(); } void EmulateUnsupportedFloatsPass::runOnOperation() { diff --git a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir index 99790cc45d490..fcd004ac554aa 100644 --- a/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir +++ b/mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir @@ -85,3 +85,14 @@ func.func @no_expansion(%x: f32) -> f32 { %y = arith.addf %x, %c : f32 func.return %y : f32 } + +// ----- + +func.func @no_promote_select(%c: i1, %x: bf16, %y: bf16) -> bf16 { +// CHECK-LABEL: @no_promote_select +// CHECK-SAME: (%[[C:.+]]: i1, %[[X:.+]]: bf16, %[[Y:.+]]: bf16) +// CHECK: %[[Z:.+]] = arith.select %[[C]], %[[X]], %[[Y]] : bf16 +// CHECK: return %[[Z]] + %z = arith.select %c, %x, %y : bf16 + func.return %z : bf16 +}