Skip to content

Commit 8962944

Browse files
YingboMashashi
andcommitted
FAST
Co-authored-by: "Shashi Gowda" <[email protected]>
1 parent a8523b8 commit 8962944

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

src/types.jl

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,19 +102,27 @@ function setmetadata(s::Symbolic, ctx::DataType, val)
102102
end
103103
end
104104

105-
Base.isequal(s::Symbolic, x) = false
106-
Base.isequal(x, s::Symbolic) = false
105+
Base.isequal(::Symbolic, x) = false
106+
Base.isequal(x, ::Symbolic) = false
107+
Base.isequal(::Symbolic, ::Symbolic) = false
107108

108-
function Base.isequal(t1::Symbolic, t2::Symbolic)
109+
function Base.isequal(a::Sym, b::Sym)
110+
symtype(a) !== symtype(b) && return false
111+
isequal(nameof(a), nameof(b))
112+
end
113+
114+
function Base.isequal(t1::Term, t2::Term)
109115
t1 === t2 && return true
110-
(istree(t1) && istree(t2)) || return false
116+
symtype(t1) !== symtype(t2) && return false
117+
111118
a1 = arguments(t1)
112119
a2 = arguments(t2)
113120

114121
isequal(operation(t1), operation(t2)) &&
115122
length(a1) == length(a2) &&
116123
all(isequal(l,r) for (l, r) in zip(a1,a2))
117124
end
125+
118126
### End of interface
119127

120128
function to_symbolic(x)
@@ -386,7 +394,8 @@ end
386394
387395
Create a term that is similar in type to `t`. Extending this function allows packages
388396
using their own expression types with SymbolicUtils to define how new terms should
389-
be created.
397+
be created. Note that `similarterm` may return an object that has a
398+
different type than `t`, because `f` also influences the result.
390399
391400
## Arguments
392401
@@ -398,7 +407,7 @@ be created.
398407
"""
399408
similarterm(t, f, args, symtype) = f(args...)
400409
similarterm(t, f, args) = similarterm(t, f, args, _promote_symtype(f, args))
401-
similarterm(::Term, f, args, symtype=nothing) = term(f, args...; type=symtype)
410+
similarterm(t::Term, f, args) = Term{_promote_symtype(f, args)}(f, args)
402411

403412
node_count(t) = istree(t) ? reduce(+, node_count(x) for x in arguments(t), init=0) + 1 : 1
404413

@@ -873,7 +882,10 @@ function mapvalues(f, d1::AbstractDict)
873882
d
874883
end
875884

876-
function similarterm(p::Union{Mul, Add, Pow}, f, args, T=nothing)
885+
const NumericTerm = Union{Term{<:Number}, Mul{<:Number},
886+
Add{<:Number}, Pow{<:Number}}
887+
888+
function similarterm(p::NumericTerm, f, args, T=nothing)
877889
if T === nothing
878890
T = _promote_symtype(f, args)
879891
end
@@ -884,7 +896,7 @@ function similarterm(p::Union{Mul, Add, Pow}, f, args, T=nothing)
884896
elseif f == (^) && length(args) == 2
885897
Pow{T, typeof.(args)...}(args...)
886898
else
887-
f(args...)
899+
Term{T}(f, args)
888900
end
889901
end
890902

0 commit comments

Comments
 (0)