Skip to content

Commit 9ffc75d

Browse files
authored
Add FMF support for fptrunc/fpext (#1131)
1 parent 4eff96c commit 9ffc75d

File tree

5 files changed

+106
-32
lines changed

5 files changed

+106
-32
lines changed

ir/instr.cpp

Lines changed: 63 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1729,9 +1729,9 @@ unique_ptr<Instr> ConversionOp::dup(Function &f, const string &suffix) const {
17291729

17301730
FpConversionOp::FpConversionOp(Type &type, std::string &&name, Value &val,
17311731
Op op, FpRoundingMode rm, FpExceptionMode ex,
1732-
unsigned flags)
1732+
unsigned flags, FastMathFlags fmath)
17331733
: Instr(type, std::move(name)), val(&val), op(op), rm(rm), ex(ex),
1734-
flags(flags) {
1734+
flags(flags), fmath(fmath) {
17351735
switch (op) {
17361736
case UIntToFP:
17371737
assert((flags & NNEG) == flags);
@@ -1740,6 +1740,14 @@ FpConversionOp::FpConversionOp(Type &type, std::string &&name, Value &val,
17401740
assert(flags == 0);
17411741
break;
17421742
}
1743+
switch (op) {
1744+
case FPTrunc:
1745+
case FPExt:
1746+
break;
1747+
default:
1748+
assert(fmath.isNone());
1749+
break;
1750+
}
17431751
}
17441752

17451753
vector<Value*> FpConversionOp::operands() const {
@@ -1774,6 +1782,8 @@ void FpConversionOp::print(ostream &os) const {
17741782
os << getName() << " = " << str;
17751783
if (flags & NNEG)
17761784
os << "nneg ";
1785+
if (!fmath.isNone())
1786+
os << fmath;
17771787
os << *val << print_type(getType(), " to ", "");
17781788
if (!rm.isDefault())
17791789
os << ", rounding=" << rm;
@@ -1783,7 +1793,8 @@ void FpConversionOp::print(ostream &os) const {
17831793

17841794
StateValue FpConversionOp::toSMT(State &s) const {
17851795
auto &v = s[*val];
1786-
function<StateValue(const expr &, const Type &, const expr&)> fn;
1796+
function<StateValue(const expr &, const Type &, const expr &)> fn;
1797+
function<StateValue(const StateValue &, const Type &, const Type &)> scalar;
17871798

17881799
switch (op) {
17891800
case SIntToFP:
@@ -1846,37 +1857,59 @@ StateValue FpConversionOp::toSMT(State &s) const {
18461857
break;
18471858
case FPExt:
18481859
case FPTrunc:
1849-
fn = [](auto &val, auto &to_type, auto &rm) -> StateValue {
1850-
return { val.float2Float(to_type.getAsFloatType()->getDummyFloat(), rm),
1851-
true };
1860+
scalar = [&](const StateValue &sv, const Type &from_type,
1861+
const Type &to_type) -> StateValue {
1862+
auto fm_poison_expr = [&](const expr &val, const expr &np,
1863+
const Type &ty) {
1864+
return fm_poison(
1865+
s, val, np, [](auto &x, auto &) { return x; }, ty, fmath, {}, true,
1866+
/*flags_out_only=*/true);
1867+
};
1868+
auto [input, input_np] =
1869+
fm_poison_expr(sv.value, sv.non_poison, from_type);
1870+
AndExpr np;
1871+
np.add(std::move(input_np));
1872+
function<StateValue(const expr &)> fn_rm = [&](auto &rm) -> StateValue {
1873+
return {from_type.getAsFloatType()->getFloat(input).float2Float(
1874+
to_type.getAsFloatType()->getDummyFloat(), rm),
1875+
true};
1876+
};
1877+
auto output = round_value(s, rm, np, fn_rm);
1878+
np.add(std::move(output.non_poison));
1879+
return fm_poison_expr(
1880+
to_type.getAsFloatType()->fromFloat(
1881+
s, output.value, from_type, from_type.isFloatType(), sv.value),
1882+
np(), to_type);
18521883
};
18531884
break;
18541885
}
18551886

1856-
auto scalar = [&](const StateValue &sv, const Type &from_type,
1857-
const Type &to_type) -> StateValue {
1858-
auto val = sv.value;
1859-
1860-
if (from_type.isFloatType()) {
1861-
auto ty = from_type.getAsFloatType();
1862-
val = ty->getFloat(val);
1863-
}
1864-
1865-
function<StateValue(const expr&)> fn_rm
1866-
= [&](auto &rm) { return fn(val, to_type, rm); };
1867-
AndExpr np;
1868-
np.add(sv.non_poison);
1887+
if (!scalar)
1888+
scalar = [&](const StateValue &sv, const Type &from_type,
1889+
const Type &to_type) -> StateValue {
1890+
auto val = sv.value;
18691891

1870-
StateValue ret = to_type.isFloatType() ? round_value(s, rm, np, fn_rm)
1871-
: fn(val, to_type, rm.toSMT());
1872-
np.add(std::move(ret.non_poison));
1892+
if (from_type.isFloatType()) {
1893+
auto ty = from_type.getAsFloatType();
1894+
val = ty->getFloat(val);
1895+
}
18731896

1874-
return { to_type.isFloatType()
1875-
? to_type.getAsFloatType()
1876-
->fromFloat(s, ret.value, from_type, from_type.isFloatType(),
1877-
sv.value)
1878-
: std::move(ret.value), np()};
1879-
};
1897+
function<StateValue(const expr &)> fn_rm = [&](auto &rm) {
1898+
return fn(val, to_type, rm);
1899+
};
1900+
AndExpr np;
1901+
np.add(sv.non_poison);
1902+
1903+
StateValue ret = to_type.isFloatType() ? round_value(s, rm, np, fn_rm)
1904+
: fn(val, to_type, rm.toSMT());
1905+
np.add(std::move(ret.non_poison));
1906+
1907+
return {to_type.isFloatType() ? to_type.getAsFloatType()->fromFloat(
1908+
s, ret.value, from_type,
1909+
from_type.isFloatType(), sv.value)
1910+
: std::move(ret.value),
1911+
np()};
1912+
};
18801913

18811914
if (getType().isVectorType()) {
18821915
vector<StateValue> vals;
@@ -1922,8 +1955,8 @@ expr FpConversionOp::getTypeConstraints(const Function &f) const {
19221955
}
19231956

19241957
unique_ptr<Instr> FpConversionOp::dup(Function &f, const string &suffix) const {
1925-
return
1926-
make_unique<FpConversionOp>(getType(), getName() + suffix, *val, op, rm);
1958+
return make_unique<FpConversionOp>(getType(), getName() + suffix, *val, op,
1959+
rm, ex, flags, fmath);
19271960
}
19281961

19291962

ir/instr.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,16 +319,20 @@ class FpConversionOp final : public Instr {
319319
FpRoundingMode rm;
320320
FpExceptionMode ex;
321321
unsigned flags;
322+
FastMathFlags fmath;
322323

323324
public:
324325
FpConversionOp(Type &type, std::string &&name, Value &val, Op op,
325326
FpRoundingMode rm = {}, FpExceptionMode ex = {},
326-
unsigned flags = None);
327+
unsigned flags = None, FastMathFlags fmath = {});
327328

328329
Op getOp() const { return op; }
329330
FpRoundingMode getRoundingMode() const { return rm; }
330331
FpExceptionMode getExceptionMode() const { return ex; }
331332
unsigned getFlags() const { return flags; }
333+
FastMathFlags getFastMathFlags() const {
334+
return fmath;
335+
}
332336

333337
std::vector<Value*> operands() const override;
334338
bool propagatesPoison() const override;

llvm_util/llvm2alive.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ class llvm2alive_ : public llvm::InstVisitor<llvm2alive_, unique_ptr<Instr>> {
316316
}
317317
return make_unique<FpConversionOp>(*ty, value_name(i), *val, op,
318318
FpRoundingMode{}, FpExceptionMode{},
319-
flags);
319+
flags, parse_fmath(i));
320320
}
321321

322322
RetTy visitFreezeInst(llvm::FreezeInst &i) {

tests/alive-tv/fp/fpext-trunc.srctgt.ll

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,27 @@ define half @src(half %0) {
77
define half @tgt(half %0) {
88
ret half %0
99
}
10+
11+
define half @src1(float noundef %a) {
12+
%b = fneg nnan float %a
13+
%c = fptrunc nnan float %b to half
14+
ret half %c
15+
}
16+
17+
define half @tgt1(float noundef %a) {
18+
%b = fptrunc nnan float %a to half
19+
%c = fneg nnan half %b
20+
ret half %c
21+
}
22+
23+
define half @src2(half noundef %a) {
24+
%av = fpext nnan half %a to float
25+
%c = fadd nnan float %av, %av
26+
%d = fptrunc nnan float %c to half
27+
ret half %d
28+
}
29+
30+
define half @tgt2(half noundef %a) {
31+
%b = fadd nnan half %a, %a
32+
ret half %b
33+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
define half @src(float noundef %a) {
2+
%b = fneg ninf float %a
3+
%c = fptrunc float %b to half
4+
ret half %c
5+
}
6+
7+
define half @tgt(float noundef %a) {
8+
%b = fptrunc ninf float %a to half
9+
%c = fneg ninf half %b
10+
ret half %c
11+
}
12+
13+
; ERROR: Target is more poisonous than source

0 commit comments

Comments
 (0)