Skip to content

Commit d7107af

Browse files
authored
Merge pull request #165 from JuliaSymbolics/myb/fixes
Terms constructed by substitution should be the same as terms constructed by evaluation
2 parents 20f777b + 62efff0 commit d7107af

File tree

2 files changed

+29
-23
lines changed

2 files changed

+29
-23
lines changed

src/methods.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ for f in monadic
9797
if f in [real]
9898
continue
9999
end
100-
@eval promote_symtype(::$(typeof(f)), T::Type{<:Number}) = Number
100+
@eval promote_symtype(::$(typeof(f)), T::Type{<:Number}) = promote_type(T, Real)
101101
@eval (::$(typeof(f)))(a::Symbolic) = term($f, a)
102102
end
103103

@@ -134,8 +134,12 @@ for (f, Domain) in [(==) => Number, (!=) => Number,
134134
end
135135
end
136136

137-
Base.:!(s::Symbolic{Bool}) = Term{Bool}(!, [s])
138-
Base.:~(s::Symbolic{Bool}) = Term{Bool}(!, [s])
137+
for f in [!, ~]
138+
@eval begin
139+
promote_symtype(::$(typeof(f)), ::Type{<:Bool}) = Bool
140+
(::$(typeof(f)))(s::Symbolic{Bool}) = Term{Bool}(!, [s])
141+
end
142+
end
139143

140144

141145
# An ifelse node, ifelse is a built-in unfortunately

src/types.jl

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,17 @@ symtype(::Symbolic{T}) where {T} = T
5858

5959
Base.isequal(s::Symbolic, x) = false
6060
Base.isequal(x, s::Symbolic) = false
61-
Base.isequal(x::Symbolic, y::Symbolic) = false
61+
62+
function Base.isequal(t1::Symbolic, t2::Symbolic)
63+
t1 === t2 && return true
64+
(istree(t1) && istree(t2)) || return false
65+
a1 = arguments(t1)
66+
a2 = arguments(t2)
67+
68+
isequal(operation(t1), operation(t2)) &&
69+
length(a1) == length(a2) &&
70+
all(isequal(l,r) for (l, r) in zip(a1,a2))
71+
end
6272
### End of interface
6373

6474
"""
@@ -155,7 +165,7 @@ function (f::Sym)(args...)
155165
end
156166

157167
"""
158-
`promote_symtype(f::Sym{FnType{X,Y}}, arg_symtypes...)`
168+
promote_symtype(f::Sym{FnType{X,Y}}, arg_symtypes...)
159169
160170
The output symtype of applying variable `f` to arugments of symtype `arg_symtypes...`.
161171
if the arguments are of the wrong type then this function will error.
@@ -287,18 +297,9 @@ function Base.hash(t::Term{T}, salt::UInt) where {T}
287297
hashvec(arguments(t), hash(operation(t), hash(T, salt)))
288298
end
289299

290-
function Base.isequal(t1::Term, t2::Term)
291-
t1 === t2 && return true
292-
a1 = arguments(t1)
293-
a2 = arguments(t2)
294-
295-
isequal(operation(t1), operation(t2)) && length(a1) == length(a2) &&
296-
all(isequal(l,r) for (l, r) in zip(a1,a2))
297-
end
298-
299300
function term(f, args...; type = nothing)
300301
if type === nothing
301-
T = rec_promote_symtype(f, symtype.(args)...)
302+
T = rec_promote_symtype(f, map(symtype, args)...)
302303
else
303304
T = type
304305
end
@@ -308,12 +309,11 @@ end
308309
"""
309310
similarterm(t, f, args)
310311
311-
Create a term that is similar in type to `t`.
312-
If `t` is a `Term` will create a `Term` with the same `symtype`
313-
Otherwise simply calls `f(args...)` by default.
312+
Create a term that is similar in type to `t` such that `symtype(similarterm(f,
313+
args...)) === symtype(f(args...))`.
314314
"""
315315
similarterm(t, f, args) = f(args...)
316-
similarterm(t::Term, f, args) = Term{symtype(t)}(f, args)
316+
similarterm(::Term, f, args) = term(f, args...)
317317

318318
node_count(t) = istree(t) ? reduce(+, node_count(x) for x in arguments(t), init=0) + 1 : 1
319319

@@ -618,7 +618,7 @@ struct Pow{X, B, E} <: Symbolic{X}
618618
exp::E
619619
end
620620

621-
function Pow(a,b)
621+
function Pow(a, b)
622622
_iszero(b) && return 1
623623
_isone(b) && return a
624624
Pow{promote_symtype(^, symtype(a), symtype(b)), typeof(a), typeof(b)}(a,b)
@@ -689,10 +689,12 @@ end
689689

690690
function similarterm(p::Union{Mul, Add, Pow}, f, args)
691691
if f === (+)
692-
Add(symtype(p), makeadd(1, 0, args...)...)
692+
T = rec_promote_symtype(f, map(symtype, args)...)
693+
Add(T, makeadd(1, 0, args...)...)
693694
elseif f == (*)
694-
Mul(symtype(p), makemul(1, args...)...)
695-
elseif f == (^)
695+
T = rec_promote_symtype(f, map(symtype, args)...)
696+
Mul(T, makemul(1, args...)...)
697+
elseif f == (^) && length(args) == 2
696698
Pow(args...)
697699
else
698700
f(args...)

0 commit comments

Comments
 (0)