Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 63 additions & 30 deletions ir/instr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1729,9 +1729,9 @@ unique_ptr<Instr> ConversionOp::dup(Function &f, const string &suffix) const {

FpConversionOp::FpConversionOp(Type &type, std::string &&name, Value &val,
Op op, FpRoundingMode rm, FpExceptionMode ex,
unsigned flags)
unsigned flags, FastMathFlags fmath)
: Instr(type, std::move(name)), val(&val), op(op), rm(rm), ex(ex),
flags(flags) {
flags(flags), fmath(fmath) {
switch (op) {
case UIntToFP:
assert((flags & NNEG) == flags);
Expand All @@ -1740,6 +1740,14 @@ FpConversionOp::FpConversionOp(Type &type, std::string &&name, Value &val,
assert(flags == 0);
break;
}
switch (op) {
case FPTrunc:
case FPExt:
break;
default:
assert(fmath.isNone());
break;
}
}

vector<Value*> FpConversionOp::operands() const {
Expand Down Expand Up @@ -1774,6 +1782,8 @@ void FpConversionOp::print(ostream &os) const {
os << getName() << " = " << str;
if (flags & NNEG)
os << "nneg ";
if (!fmath.isNone())
os << fmath;
os << *val << print_type(getType(), " to ", "");
if (!rm.isDefault())
os << ", rounding=" << rm;
Expand All @@ -1783,7 +1793,8 @@ void FpConversionOp::print(ostream &os) const {

StateValue FpConversionOp::toSMT(State &s) const {
auto &v = s[*val];
function<StateValue(const expr &, const Type &, const expr&)> fn;
function<StateValue(const expr &, const Type &, const expr &)> fn;
function<StateValue(const StateValue &, const Type &, const Type &)> scalar;

switch (op) {
case SIntToFP:
Expand Down Expand Up @@ -1846,37 +1857,59 @@ StateValue FpConversionOp::toSMT(State &s) const {
break;
case FPExt:
case FPTrunc:
fn = [](auto &val, auto &to_type, auto &rm) -> StateValue {
return { val.float2Float(to_type.getAsFloatType()->getDummyFloat(), rm),
true };
scalar = [&](const StateValue &sv, const Type &from_type,
const Type &to_type) -> StateValue {
auto fm_poison_expr = [&](const expr &val, const expr &np,
const Type &ty) {
return fm_poison(
s, val, np, [](auto &x, auto &) { return x; }, ty, fmath, {}, true,
/*flags_out_only=*/true);
};
auto [input, input_np] =
fm_poison_expr(sv.value, sv.non_poison, from_type);
AndExpr np;
np.add(std::move(input_np));
function<StateValue(const expr &)> fn_rm = [&](auto &rm) -> StateValue {
return {from_type.getAsFloatType()->getFloat(input).float2Float(
to_type.getAsFloatType()->getDummyFloat(), rm),
true};
};
auto output = round_value(s, rm, np, fn_rm);
np.add(std::move(output.non_poison));
return fm_poison_expr(
to_type.getAsFloatType()->fromFloat(
s, output.value, from_type, from_type.isFloatType(), sv.value),
np(), to_type);
};
break;
}

auto scalar = [&](const StateValue &sv, const Type &from_type,
const Type &to_type) -> StateValue {
auto val = sv.value;

if (from_type.isFloatType()) {
auto ty = from_type.getAsFloatType();
val = ty->getFloat(val);
}

function<StateValue(const expr&)> fn_rm
= [&](auto &rm) { return fn(val, to_type, rm); };
AndExpr np;
np.add(sv.non_poison);
if (!scalar)
scalar = [&](const StateValue &sv, const Type &from_type,
const Type &to_type) -> StateValue {
auto val = sv.value;

StateValue ret = to_type.isFloatType() ? round_value(s, rm, np, fn_rm)
: fn(val, to_type, rm.toSMT());
np.add(std::move(ret.non_poison));
if (from_type.isFloatType()) {
auto ty = from_type.getAsFloatType();
val = ty->getFloat(val);
}

return { to_type.isFloatType()
? to_type.getAsFloatType()
->fromFloat(s, ret.value, from_type, from_type.isFloatType(),
sv.value)
: std::move(ret.value), np()};
};
function<StateValue(const expr &)> fn_rm = [&](auto &rm) {
return fn(val, to_type, rm);
};
AndExpr np;
np.add(sv.non_poison);

StateValue ret = to_type.isFloatType() ? round_value(s, rm, np, fn_rm)
: fn(val, to_type, rm.toSMT());
np.add(std::move(ret.non_poison));

return {to_type.isFloatType() ? to_type.getAsFloatType()->fromFloat(
s, ret.value, from_type,
from_type.isFloatType(), sv.value)
: std::move(ret.value),
np()};
};

if (getType().isVectorType()) {
vector<StateValue> vals;
Expand Down Expand Up @@ -1922,8 +1955,8 @@ expr FpConversionOp::getTypeConstraints(const Function &f) const {
}

unique_ptr<Instr> FpConversionOp::dup(Function &f, const string &suffix) const {
return
make_unique<FpConversionOp>(getType(), getName() + suffix, *val, op, rm);
return make_unique<FpConversionOp>(getType(), getName() + suffix, *val, op,
rm, ex, flags, fmath);
}


Expand Down
6 changes: 5 additions & 1 deletion ir/instr.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,16 +319,20 @@ class FpConversionOp final : public Instr {
FpRoundingMode rm;
FpExceptionMode ex;
unsigned flags;
FastMathFlags fmath;

public:
FpConversionOp(Type &type, std::string &&name, Value &val, Op op,
FpRoundingMode rm = {}, FpExceptionMode ex = {},
unsigned flags = None);
unsigned flags = None, FastMathFlags fmath = {});

Op getOp() const { return op; }
FpRoundingMode getRoundingMode() const { return rm; }
FpExceptionMode getExceptionMode() const { return ex; }
unsigned getFlags() const { return flags; }
FastMathFlags getFastMathFlags() const {
return fmath;
}

std::vector<Value*> operands() const override;
bool propagatesPoison() const override;
Expand Down
2 changes: 1 addition & 1 deletion llvm_util/llvm2alive.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ class llvm2alive_ : public llvm::InstVisitor<llvm2alive_, unique_ptr<Instr>> {
}
return make_unique<FpConversionOp>(*ty, value_name(i), *val, op,
FpRoundingMode{}, FpExceptionMode{},
flags);
flags, parse_fmath(i));
}

RetTy visitFreezeInst(llvm::FreezeInst &i) {
Expand Down
24 changes: 24 additions & 0 deletions tests/alive-tv/fp/fpext-trunc.srctgt.ll
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,27 @@ define half @src(half %0) {
define half @tgt(half %0) {
ret half %0
}

define half @src1(float noundef %a) {
%b = fneg nnan float %a
%c = fptrunc nnan float %b to half
ret half %c
}

define half @tgt1(float noundef %a) {
%b = fptrunc nnan float %a to half
%c = fneg nnan half %b
ret half %c
}

define half @src2(half noundef %a) {
%av = fpext nnan half %a to float
%c = fadd nnan float %av, %av
%d = fptrunc nnan float %c to half
ret half %d
}

define half @tgt2(half noundef %a) {
%b = fadd nnan half %a, %a
ret half %b
}
13 changes: 13 additions & 0 deletions tests/alive-tv/fp/fptrunc-fmf-fail.srctgt.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
define half @src(float noundef %a) {
%b = fneg ninf float %a
%c = fptrunc float %b to half
ret half %c
}

define half @tgt(float noundef %a) {
%b = fptrunc ninf float %a to half
%c = fneg ninf half %b
ret half %c
}

; ERROR: Target is more poisonous than source
Loading