diff --git a/ir/instr.cpp b/ir/instr.cpp index f1ea89cc7..84ca923b5 100644 --- a/ir/instr.cpp +++ b/ir/instr.cpp @@ -1729,9 +1729,9 @@ unique_ptr ConversionOp::dup(Function &f, const string &suffix) const { FpConversionOp::FpConversionOp(Type &type, std::string &&name, Value &val, Op op, FpRoundingMode rm, FpExceptionMode ex, - unsigned flags) + unsigned flags, FastMathFlags fmath) : Instr(type, std::move(name)), val(&val), op(op), rm(rm), ex(ex), - flags(flags) { + flags(flags), fmath(fmath) { switch (op) { case UIntToFP: assert((flags & NNEG) == flags); @@ -1740,6 +1740,14 @@ FpConversionOp::FpConversionOp(Type &type, std::string &&name, Value &val, assert(flags == 0); break; } + switch (op) { + case FPTrunc: + case FPExt: + break; + default: + assert(fmath.isNone()); + break; + } } vector FpConversionOp::operands() const { @@ -1774,6 +1782,8 @@ void FpConversionOp::print(ostream &os) const { os << getName() << " = " << str; if (flags & NNEG) os << "nneg "; + if (!fmath.isNone()) + os << fmath; os << *val << print_type(getType(), " to ", ""); if (!rm.isDefault()) os << ", rounding=" << rm; @@ -1783,7 +1793,8 @@ void FpConversionOp::print(ostream &os) const { StateValue FpConversionOp::toSMT(State &s) const { auto &v = s[*val]; - function fn; + function fn; + function scalar; switch (op) { case SIntToFP: @@ -1846,37 +1857,59 @@ StateValue FpConversionOp::toSMT(State &s) const { break; case FPExt: case FPTrunc: - fn = [](auto &val, auto &to_type, auto &rm) -> StateValue { - return { val.float2Float(to_type.getAsFloatType()->getDummyFloat(), rm), - true }; + scalar = [&](const StateValue &sv, const Type &from_type, + const Type &to_type) -> StateValue { + auto fm_poison_expr = [&](const expr &val, const expr &np, + const Type &ty) { + return fm_poison( + s, val, np, [](auto &x, auto &) { return x; }, ty, fmath, {}, true, + /*flags_out_only=*/true); + }; + auto [input, input_np] = + fm_poison_expr(sv.value, sv.non_poison, from_type); + AndExpr np; + np.add(std::move(input_np)); + function fn_rm = [&](auto &rm) -> StateValue { + return {from_type.getAsFloatType()->getFloat(input).float2Float( + to_type.getAsFloatType()->getDummyFloat(), rm), + true}; + }; + auto output = round_value(s, rm, np, fn_rm); + np.add(std::move(output.non_poison)); + return fm_poison_expr( + to_type.getAsFloatType()->fromFloat( + s, output.value, from_type, from_type.isFloatType(), sv.value), + np(), to_type); }; break; } - auto scalar = [&](const StateValue &sv, const Type &from_type, - const Type &to_type) -> StateValue { - auto val = sv.value; - - if (from_type.isFloatType()) { - auto ty = from_type.getAsFloatType(); - val = ty->getFloat(val); - } - - function fn_rm - = [&](auto &rm) { return fn(val, to_type, rm); }; - AndExpr np; - np.add(sv.non_poison); + if (!scalar) + scalar = [&](const StateValue &sv, const Type &from_type, + const Type &to_type) -> StateValue { + auto val = sv.value; - StateValue ret = to_type.isFloatType() ? round_value(s, rm, np, fn_rm) - : fn(val, to_type, rm.toSMT()); - np.add(std::move(ret.non_poison)); + if (from_type.isFloatType()) { + auto ty = from_type.getAsFloatType(); + val = ty->getFloat(val); + } - return { to_type.isFloatType() - ? to_type.getAsFloatType() - ->fromFloat(s, ret.value, from_type, from_type.isFloatType(), - sv.value) - : std::move(ret.value), np()}; - }; + function fn_rm = [&](auto &rm) { + return fn(val, to_type, rm); + }; + AndExpr np; + np.add(sv.non_poison); + + StateValue ret = to_type.isFloatType() ? round_value(s, rm, np, fn_rm) + : fn(val, to_type, rm.toSMT()); + np.add(std::move(ret.non_poison)); + + return {to_type.isFloatType() ? to_type.getAsFloatType()->fromFloat( + s, ret.value, from_type, + from_type.isFloatType(), sv.value) + : std::move(ret.value), + np()}; + }; if (getType().isVectorType()) { vector vals; @@ -1922,8 +1955,8 @@ expr FpConversionOp::getTypeConstraints(const Function &f) const { } unique_ptr FpConversionOp::dup(Function &f, const string &suffix) const { - return - make_unique(getType(), getName() + suffix, *val, op, rm); + return make_unique(getType(), getName() + suffix, *val, op, + rm, ex, flags, fmath); } diff --git a/ir/instr.h b/ir/instr.h index 146cb97ed..302f4fa8f 100644 --- a/ir/instr.h +++ b/ir/instr.h @@ -319,16 +319,20 @@ class FpConversionOp final : public Instr { FpRoundingMode rm; FpExceptionMode ex; unsigned flags; + FastMathFlags fmath; public: FpConversionOp(Type &type, std::string &&name, Value &val, Op op, FpRoundingMode rm = {}, FpExceptionMode ex = {}, - unsigned flags = None); + unsigned flags = None, FastMathFlags fmath = {}); Op getOp() const { return op; } FpRoundingMode getRoundingMode() const { return rm; } FpExceptionMode getExceptionMode() const { return ex; } unsigned getFlags() const { return flags; } + FastMathFlags getFastMathFlags() const { + return fmath; + } std::vector operands() const override; bool propagatesPoison() const override; diff --git a/llvm_util/llvm2alive.cpp b/llvm_util/llvm2alive.cpp index ec9ebe6f0..792f02159 100644 --- a/llvm_util/llvm2alive.cpp +++ b/llvm_util/llvm2alive.cpp @@ -316,7 +316,7 @@ class llvm2alive_ : public llvm::InstVisitor> { } return make_unique(*ty, value_name(i), *val, op, FpRoundingMode{}, FpExceptionMode{}, - flags); + flags, parse_fmath(i)); } RetTy visitFreezeInst(llvm::FreezeInst &i) { diff --git a/tests/alive-tv/fp/fpext-trunc.srctgt.ll b/tests/alive-tv/fp/fpext-trunc.srctgt.ll index 994a46d72..7fe8a0d82 100644 --- a/tests/alive-tv/fp/fpext-trunc.srctgt.ll +++ b/tests/alive-tv/fp/fpext-trunc.srctgt.ll @@ -7,3 +7,27 @@ define half @src(half %0) { define half @tgt(half %0) { ret half %0 } + +define half @src1(float noundef %a) { + %b = fneg nnan float %a + %c = fptrunc nnan float %b to half + ret half %c +} + +define half @tgt1(float noundef %a) { + %b = fptrunc nnan float %a to half + %c = fneg nnan half %b + ret half %c +} + +define half @src2(half noundef %a) { + %av = fpext nnan half %a to float + %c = fadd nnan float %av, %av + %d = fptrunc nnan float %c to half + ret half %d +} + +define half @tgt2(half noundef %a) { + %b = fadd nnan half %a, %a + ret half %b +} diff --git a/tests/alive-tv/fp/fptrunc-fmf-fail.srctgt.ll b/tests/alive-tv/fp/fptrunc-fmf-fail.srctgt.ll new file mode 100644 index 000000000..a46c9fbb8 --- /dev/null +++ b/tests/alive-tv/fp/fptrunc-fmf-fail.srctgt.ll @@ -0,0 +1,13 @@ +define half @src(float noundef %a) { + %b = fneg ninf float %a + %c = fptrunc float %b to half + ret half %c +} + +define half @tgt(float noundef %a) { + %b = fptrunc ninf float %a to half + %c = fneg ninf half %b + ret half %c +} + +; ERROR: Target is more poisonous than source