|
1 | 1 | module Code |
2 | 2 |
|
3 | | -using StaticArrays, LabelledArrays, SparseArrays, LinearAlgebra |
| 3 | +using StaticArrays, LabelledArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions |
4 | 4 |
|
5 | 5 | export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr, |
6 | 6 | SetArray, MakeArray, MakeSparseArray, MakeTuple, AtIndex, |
@@ -96,7 +96,30 @@ Base.convert(::Type{Assignment}, p::Pair) = Assignment(pair[1], pair[2]) |
96 | 96 |
|
97 | 97 | toexpr(a::Assignment, st) = :($(toexpr(a.lhs, st)) = $(toexpr(a.rhs, st))) |
98 | 98 |
|
99 | | -function_to_expr(op, args, st) = nothing |
| 99 | +const NaNMathFuns = ( |
| 100 | + sin, |
| 101 | + cos, |
| 102 | + tan, |
| 103 | + asin, |
| 104 | + acos, |
| 105 | + acosh, |
| 106 | + atanh, |
| 107 | + log, |
| 108 | + log2, |
| 109 | + log10, |
| 110 | + lgamma, |
| 111 | + log1p, |
| 112 | + sqrt, |
| 113 | +) |
| 114 | +function function_to_expr(op, O, st) |
| 115 | + op in NaNMathFuns || return nothing |
| 116 | + name = nameof(op) |
| 117 | + fun = GlobalRef(NaNMath, name) |
| 118 | + args = map(Base.Fix2(toexpr, st), arguments(O)) |
| 119 | + expr = Expr(:call, fun) |
| 120 | + append!(expr.args, args) |
| 121 | + return expr |
| 122 | +end |
100 | 123 |
|
101 | 124 | function function_to_expr(op::Union{typeof(*),typeof(+)}, O, st) |
102 | 125 | out = get(st.rewrites, O, nothing) |
|
0 commit comments