Skip to content

Commit 971c739

Browse files
committed
Fix similarterm
We need `symtype(similarterm(f, args...)) === symtype(f(args...))` to make sure that terms constructed by substitution are the same as terms constructed by evaluation.
1 parent e6b0aa8 commit 971c739

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

src/methods.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ for f in monadic
9797
if f in [real]
9898
continue
9999
end
100-
@eval promote_symtype(::$(typeof(f)), T::Type{<:Number}) = Number
100+
@eval promote_symtype(::$(typeof(f)), T::Type{<:Number}) = promote_type(T, Real)
101101
@eval (::$(typeof(f)))(a::Symbolic) = term($f, a)
102102
end
103103

src/types.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -309,12 +309,11 @@ end
309309
"""
310310
similarterm(t, f, args)
311311
312-
Create a term that is similar in type to `t`.
313-
If `t` is a `Term` will create a `Term` with the same `symtype`
314-
Otherwise simply calls `f(args...)` by default.
312+
Create a term that is similar in type to `t` such that `symtype(similarterm(f,
313+
args...)) === symtype(f(args...))`.
315314
"""
316315
similarterm(t, f, args) = f(args...)
317-
similarterm(t::Term, f, args) = Term{symtype(t)}(f, args)
316+
similarterm(::Term, f, args) = Term(f, args)
318317

319318
node_count(t) = istree(t) ? reduce(+, node_count(x) for x in arguments(t), init=0) + 1 : 1
320319

@@ -690,10 +689,12 @@ end
690689

691690
function similarterm(p::Union{Mul, Add, Pow}, f, args)
692691
if f === (+)
693-
Add(symtype(p), makeadd(1, 0, args...)...)
692+
T = rec_promote_symtype(f, map(symtype, args)...)
693+
Add(T, makeadd(1, 0, args...)...)
694694
elseif f == (*)
695-
Mul(symtype(p), makemul(1, args...)...)
696-
elseif f == (^)
695+
T = rec_promote_symtype(f, map(symtype, args)...)
696+
Mul(T, makemul(1, args...)...)
697+
elseif f == (^) && length(args) == 2
697698
Pow(args...)
698699
else
699700
f(args...)

0 commit comments

Comments
 (0)