Skip to content

Commit 5a923ea

Browse files
authored
Merge pull request #536 from JuliaSymbolics/myb/reg
Register NaNMath functions
2 parents 325b673 + 4f1540b commit 5a923ea

File tree

3 files changed

+18
-4
lines changed

3 files changed

+18
-4
lines changed

src/code.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ const NaNMathFuns = (
112112
sqrt,
113113
)
114114
function function_to_expr(op, O, st)
115-
op in NaNMathFuns || return nothing
115+
(get(st.rewrites, :nanmath, false) && op in NaNMathFuns) || return nothing
116116
name = nameof(op)
117117
fun = GlobalRef(NaNMath, name)
118118
args = map(Base.Fix2(toexpr, st), arguments(O))

src/methods.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import NaNMath
12
import SpecialFunctions: gamma, loggamma, erf, erfc, erfcinv, erfi, erfcx,
23
dawson, digamma, trigamma, invdigamma, polygamma,
34
airyai, airyaiprime, airybi, airybiprime, besselj0,
@@ -12,9 +13,12 @@ const monadic = [deg2rad, rad2deg, transpose, asind, log1p, acsch,
1213
atand, sec, acscd, cot, exp2, expm1, atanh, gamma,
1314
loggamma, erf, erfc, erfcinv, erfi, erfcx, dawson, digamma,
1415
trigamma, invdigamma, polygamma, airyai, airyaiprime, airybi,
15-
airybiprime, besselj0, besselj1, bessely0, bessely1, isfinite]
16+
airybiprime, besselj0, besselj1, bessely0, bessely1, isfinite,
17+
NaNMath.sin, NaNMath.cos, NaNMath.tan, NaNMath.asin, NaNMath.acos,
18+
NaNMath.acosh, NaNMath.atanh, NaNMath.log, NaNMath.log2,
19+
NaNMath.log10, NaNMath.lgamma, NaNMath.log1p, NaNMath.sqrt]
1620

17-
const diadic = [max, min, hypot, atan, mod, rem, copysign,
21+
const diadic = [max, min, hypot, atan, NaNMath.atanh, mod, rem, copysign,
1822
besselj, bessely, besseli, besselk, hankelh1, hankelh2,
1923
polygamma, beta, logbeta]
2024
const previously_declared_for = Set([])

test/code.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Test, SymbolicUtils
2+
using NaNMath
23
using SymbolicUtils.Code
34
using SymbolicUtils.Code: LazyState
45
using StaticArrays
@@ -7,6 +8,8 @@ using SparseArrays
78
using LinearAlgebra
89

910
test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linenums!(b))
11+
nanmath_st = Code.NameState()
12+
nanmath_st.rewrites[:nanmath] = true
1013

1114
@testset "Code" begin
1215
@syms a b c d e p q t x(t) y(t) z(t)
@@ -83,6 +86,13 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
8386
end)
8487
@test toexpr(SetArray(true, a, [x(t), AtIndex(9, b), c])).head == :macrocall
8588

89+
f = GlobalRef(NaNMath, :sin)
90+
test_repr(toexpr(LiteralExpr(:(let x=1, y=2
91+
$(sin(a+b))
92+
end)), nanmath_st),
93+
:(let x = 1, y = 2
94+
$(f)($(+)(a, b))
95+
end))
8696
test_repr(toexpr(LiteralExpr(:(let x=1, y=2
8797
$(sin(a+b))
8898
end))),
@@ -190,7 +200,7 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
190200
@test f(1) == 1
191201
@test f(2) == 2
192202

193-
f = eval(toexpr(Func([a, b], [], sqrt(a - b))))
203+
f = eval(toexpr(Func([a, b], [], sqrt(a - b)), nanmath_st))
194204
@test isnan(f(0, 10))
195205
@test f(10, 2) sqrt(8)
196206
end

0 commit comments

Comments
 (0)