Skip to content

Commit 7ef8014

Browse files
authored
Merge pull request #535 from JuliaSymbolics/myb/nanmath
Use NaNMath lowering by default
2 parents 3dc99d4 + 382a88f commit 7ef8014

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

src/code.jl

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module Code
22

3-
using StaticArrays, LabelledArrays, SparseArrays, LinearAlgebra
3+
using StaticArrays, LabelledArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions
44

55
export toexpr, Assignment, (), Let, Func, DestructuredArgs, LiteralExpr,
66
SetArray, MakeArray, MakeSparseArray, MakeTuple, AtIndex,
@@ -96,7 +96,30 @@ Base.convert(::Type{Assignment}, p::Pair) = Assignment(pair[1], pair[2])
9696

9797
toexpr(a::Assignment, st) = :($(toexpr(a.lhs, st)) = $(toexpr(a.rhs, st)))
9898

99-
function_to_expr(op, args, st) = nothing
99+
const NaNMathFuns = (
100+
sin,
101+
cos,
102+
tan,
103+
asin,
104+
acos,
105+
acosh,
106+
atanh,
107+
log,
108+
log2,
109+
log10,
110+
lgamma,
111+
log1p,
112+
sqrt,
113+
)
114+
function function_to_expr(op, O, st)
115+
op in NaNMathFuns || return nothing
116+
name = nameof(op)
117+
fun = GlobalRef(NaNMath, name)
118+
args = map(Base.Fix2(toexpr, st), arguments(O))
119+
expr = Expr(:call, fun)
120+
append!(expr.args, args)
121+
return expr
122+
end
100123

101124
function function_to_expr(op::Union{typeof(*),typeof(+)}, O, st)
102125
out = get(st.rewrites, O, nothing)

test/code.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,10 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
189189
f = eval(toexpr(Func([a+b], [], a+b)))
190190
@test f(1) == 1
191191
@test f(2) == 2
192+
193+
f = eval(toexpr(Func([a, b], [], sqrt(a - b))))
194+
@test isnan(f(0, 10))
195+
@test f(10, 2) sqrt(8)
192196
end
193197

194198
let

0 commit comments

Comments
 (0)