Skip to content

Commit 9d1a99b

Browse files
authored
Merge pull request #254 from JuliaSymbolics/myb/perf
Specialize `Base.isequal`
2 parents a8523b8 + 88bef2a commit 9d1a99b

File tree

1 file changed

+27
-17
lines changed

1 file changed

+27
-17
lines changed

src/types.jl

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -102,19 +102,11 @@ function setmetadata(s::Symbolic, ctx::DataType, val)
102102
end
103103
end
104104

105-
Base.isequal(s::Symbolic, x) = false
106-
Base.isequal(x, s::Symbolic) = false
105+
Base.isequal(::Symbolic, x) = false
106+
Base.isequal(x, ::Symbolic) = false
107+
Base.isequal(::Symbolic, ::Symbolic) = false
107108

108-
function Base.isequal(t1::Symbolic, t2::Symbolic)
109-
t1 === t2 && return true
110-
(istree(t1) && istree(t2)) || return false
111-
a1 = arguments(t1)
112-
a2 = arguments(t2)
113109

114-
isequal(operation(t1), operation(t2)) &&
115-
length(a1) == length(a2) &&
116-
all(isequal(l,r) for (l, r) in zip(a1,a2))
117-
end
118110
### End of interface
119111

120112
function to_symbolic(x)
@@ -181,7 +173,10 @@ end
181173

182174
Base.hash(s::Sym{T}, u::UInt) where {T} = hash(T, hash(s.name, u))
183175

184-
Base.isequal(v1::Sym{T}, v2::Sym{T}) where {T} = v1 === v2
176+
function Base.isequal(a::Sym, b::Sym)
177+
symtype(a) !== symtype(b) && return false
178+
isequal(nameof(a), nameof(b))
179+
end
185180

186181
Base.show(io::IO, v::Sym) = Base.show_unquoted(io, v.name)
187182

@@ -341,6 +336,18 @@ operation(x::Term) = getfield(x, :f)
341336

342337
arguments(x::Term) = getfield(x, :arguments)
343338

339+
function Base.isequal(t1::Term, t2::Term)
340+
t1 === t2 && return true
341+
symtype(t1) !== symtype(t2) && return false
342+
343+
a1 = arguments(t1)
344+
a2 = arguments(t2)
345+
346+
isequal(operation(t1), operation(t2)) &&
347+
length(a1) == length(a2) &&
348+
all(isequal(l,r) for (l, r) in zip(a1,a2))
349+
end
350+
344351
## This is much faster than hash of an array of Any
345352
hashvec(xs, z) = foldr(hash, xs, init=z)
346353

@@ -371,7 +378,6 @@ function _promote_symtype(f, args)
371378
end
372379
end
373380

374-
375381
function term(f, args...; type = nothing)
376382
if type === nothing
377383
T = _promote_symtype(f, args)
@@ -386,7 +392,8 @@ end
386392
387393
Create a term that is similar in type to `t`. Extending this function allows packages
388394
using their own expression types with SymbolicUtils to define how new terms should
389-
be created.
395+
be created. Note that `similarterm` may return an object that has a
396+
different type than `t`, because `f` also influences the result.
390397
391398
## Arguments
392399
@@ -398,7 +405,7 @@ be created.
398405
"""
399406
similarterm(t, f, args, symtype) = f(args...)
400407
similarterm(t, f, args) = similarterm(t, f, args, _promote_symtype(f, args))
401-
similarterm(::Term, f, args, symtype=nothing) = term(f, args...; type=symtype)
408+
similarterm(t::Term, f, args) = Term{_promote_symtype(f, args)}(f, args)
402409

403410
node_count(t) = istree(t) ? reduce(+, node_count(x) for x in arguments(t), init=0) + 1 : 1
404411

@@ -873,7 +880,10 @@ function mapvalues(f, d1::AbstractDict)
873880
d
874881
end
875882

876-
function similarterm(p::Union{Mul, Add, Pow}, f, args, T=nothing)
883+
const NumericTerm = Union{Term{<:Number}, Mul{<:Number},
884+
Add{<:Number}, Pow{<:Number}}
885+
886+
function similarterm(p::NumericTerm, f, args, T=nothing)
877887
if T === nothing
878888
T = _promote_symtype(f, args)
879889
end
@@ -884,7 +894,7 @@ function similarterm(p::Union{Mul, Add, Pow}, f, args, T=nothing)
884894
elseif f == (^) && length(args) == 2
885895
Pow{T, typeof.(args)...}(args...)
886896
else
887-
f(args...)
897+
Term{T}(f, args)
888898
end
889899
end
890900

0 commit comments

Comments
 (0)