Skip to content

Commit 59e3aa6

Browse files
Merge pull request #673 from Blablablanca/hash-consing2
Apply hash consing to all `BasicSymbolic` subtypes
2 parents 1af55d4 + d716d3b commit 59e3aa6

File tree

2 files changed

+138
-10
lines changed

2 files changed

+138
-10
lines changed

src/types.jl

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,11 @@ function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple
9797
# Call outer constructor because hash consing cannot be applied in inner constructor
9898
@compactified obj::BasicSymbolic begin
9999
Sym => Sym{T}(nt_new.name; nt_new...)
100+
Term => Term{T}(nt_new.f, nt_new.arguments; nt_new...)
101+
Add => Add(T, nt_new.coeff, nt_new.dict; nt_new...)
102+
Mul => Mul(T, nt_new.coeff, nt_new.dict; nt_new...)
103+
Div => Div{T}(nt_new.num, nt_new.den, nt_new.simplified; nt_new...)
104+
Pow => Pow{T}(nt_new.base, nt_new.exp; nt_new...)
100105
_ => Unityper.rt_constructor(obj){T}(;nt_new...)
101106
end
102107
end
@@ -298,6 +303,7 @@ Base.nameof(s::BasicSymbolic) = issym(s) ? s.name : error("None Sym BasicSymboli
298303

299304
## This is much faster than hash of an array of Any
300305
hashvec(xs, z) = foldr(hash, xs, init=z)
306+
hashvec2(xs, z) = foldr(hash2, xs, init=z)
301307
const SYM_SALT = 0x4de7d7c66d41da43 % UInt
302308
const ADD_SALT = 0xaddaddaddaddadda % UInt
303309
const SUB_SALT = 0xaaaaaaaaaaaaaaaa % UInt
@@ -344,10 +350,43 @@ objects. Unlike `Base.hash`, which only considers the expression structure, `has
344350
includes the metadata and symtype in the hash calculation. This can be beneficial for hash
345351
consing, allowing for more effective deduplication of symbolically equivalent expressions
346352
with different metadata or symtypes.
353+
354+
Equivalent numbers of different types, such as `0.5::Float64` and
355+
`(1 // 2)::Rational{Int64}`, have the same default `Base.hash` value. The `hash2` function
356+
distinguishes these by including their numeric types in the hash calculation to ensure that
357+
symbolically equivalent expressions with different numeric types are treated as distinct
358+
objects.
347359
"""
360+
hash2(s, salt::UInt) = hash(s, salt)
361+
function hash2(n::T, salt::UInt) where {T <: Number}
362+
hash(T, hash(n, salt))
363+
end
348364
hash2(s::BasicSymbolic) = hash2(s, zero(UInt))
349365
function hash2(s::BasicSymbolic{T}, salt::UInt)::UInt where {T}
350-
hash(metadata(s), hash(T, hash(s, salt)))
366+
E = exprtype(s)
367+
h::UInt = 0
368+
if E === SYM
369+
h = hash(nameof(s), salt SYM_SALT)
370+
elseif E === ADD || E === MUL
371+
hashoffset = isadd(s) ? ADD_SALT : SUB_SALT
372+
hv = Base.hasha_seed
373+
for (k, v) in s.dict
374+
hv ⊻= hash2(k, hash(v))
375+
end
376+
h = hash(hv, salt)
377+
h = hash(hashoffset, hash2(s.coeff, h))
378+
elseif E === DIV
379+
h = hash2(s.num, hash2(s.den, salt DIV_SALT))
380+
elseif E === POW
381+
h = hash2(s.exp, hash2(s.base, salt POW_SALT))
382+
elseif E === TERM
383+
op = operation(s)
384+
oph = op isa Function ? nameof(op) : op
385+
h = hashvec2(arguments(s), hash(oph, salt))
386+
else
387+
error_on_type()
388+
end
389+
hash(metadata(s), hash(T, h))
351390
end
352391

353392
###
@@ -395,7 +434,8 @@ function Term{T}(f, args; kw...) where T
395434
args = convert(Vector{Any}, args)
396435
end
397436

398-
Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), kw...)
437+
s = Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), kw...)
438+
BasicSymbolic(s)
399439
end
400440

401441
function Term(f, args; metadata=NO_METADATA)
@@ -415,7 +455,8 @@ function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T
415455
end
416456
end
417457

418-
Add{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
458+
s = Add{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
459+
BasicSymbolic(s)
419460
end
420461

421462
function Mul(T, a, b; metadata=NO_METADATA, kw...)
@@ -430,7 +471,8 @@ function Mul(T, a, b; metadata=NO_METADATA, kw...)
430471
else
431472
coeff = a
432473
dict = b
433-
Mul{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
474+
s = Mul{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...)
475+
BasicSymbolic(s)
434476
end
435477
end
436478

@@ -461,7 +503,7 @@ function maybe_intcoeff(x)
461503
end
462504
end
463505

464-
function Div{T}(n, d, simplified=false; metadata=nothing) where {T}
506+
function Div{T}(n, d, simplified=false; metadata=nothing, kwargs...) where {T}
465507
if T<:Number && !(T<:SafeReal)
466508
n, d = quick_cancel(n, d)
467509
end
@@ -495,7 +537,8 @@ function Div{T}(n, d, simplified=false; metadata=nothing) where {T}
495537
end
496538
end
497539

498-
Div{T}(; num=n, den=d, simplified, arguments=[], metadata)
540+
s = Div{T}(; num=n, den=d, simplified, arguments=[], metadata)
541+
BasicSymbolic(s)
499542
end
500543

501544
function Div(n,d, simplified=false; kw...)
@@ -509,14 +552,15 @@ end
509552

510553
@inline denominators(x) = isdiv(x) ? numerators(x.den) : Any[1]
511554

512-
function Pow{T}(a, b; metadata=NO_METADATA) where {T}
555+
function Pow{T}(a, b; metadata=NO_METADATA, kwargs...) where {T}
513556
_iszero(b) && return 1
514557
_isone(b) && return a
515-
Pow{T}(; base=a, exp=b, arguments=[], metadata)
558+
s = Pow{T}(; base=a, exp=b, arguments=[], metadata)
559+
BasicSymbolic(s)
516560
end
517561

518-
function Pow(a, b; metadata=NO_METADATA)
519-
Pow{promote_symtype(^, symtype(a), symtype(b))}(makepow(a, b)..., metadata=metadata)
562+
function Pow(a, b; metadata = NO_METADATA, kwargs...)
563+
Pow{promote_symtype(^, symtype(a), symtype(b))}(makepow(a, b)...; metadata, kwargs...)
520564
end
521565

522566
function toterm(t::BasicSymbolic{T}) where T

test/hash_consing.jl

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using SymbolicUtils, Test
2+
using SymbolicUtils: Term, Add, Mul, Div, Pow, hash2
23

34
struct Ctx1 end
45
struct Ctx2 end
@@ -24,3 +25,86 @@ struct Ctx2 end
2425
xm3 = setmetadata(x1, Ctx2, "meta_2")
2526
@test xm1 !== xm3
2627
end
28+
29+
@syms a b c
30+
31+
@testset "Term" begin
32+
t1 = sin(a)
33+
t2 = sin(a)
34+
@test t1 === t2
35+
t3 = Term(identity,[a])
36+
t4 = Term(identity,[a])
37+
@test t3 === t4
38+
t5 = Term{Int}(identity,[a])
39+
@test t3 !== t5
40+
tm1 = setmetadata(t1, Ctx1, "meta_1")
41+
@test t1 !== tm1
42+
end
43+
44+
@testset "Add" begin
45+
d1 = a + b
46+
d2 = b + a
47+
@test d1 === d2
48+
d3 = b - 2 + a
49+
d4 = a + b - 2
50+
@test d3 === d4
51+
d5 = Add(Int, 0, Dict(a => 1, b => 1))
52+
@test d5 !== d1
53+
54+
dm1 = setmetadata(d1,Ctx1,"meta_1")
55+
@test d1 !== dm1
56+
end
57+
58+
@testset "Mul" begin
59+
m1 = a*b
60+
m2 = b*a
61+
@test m1 === m2
62+
m3 = 6*a*b
63+
m4 = 3*a*2*b
64+
@test m3 === m4
65+
m5 = Mul(Int, 1, Dict(a => 1, b => 1))
66+
@test m5 !== m1
67+
68+
mm1 = setmetadata(m1, Ctx1, "meta_1")
69+
@test m1 !== mm1
70+
end
71+
72+
@testset "Div" begin
73+
v1 = a/b
74+
v2 = a/b
75+
@test v1 === v2
76+
v3 = -1/a
77+
v4 = -1/a
78+
@test v3 === v4
79+
v5 = 3a/6
80+
v6 = 2a/4
81+
@test v5 === v6
82+
v7 = Div{Float64}(-1,a)
83+
@test v7 !== v3
84+
85+
vm1 = setmetadata(v1,Ctx1, "meta_1")
86+
@test vm1 !== v1
87+
end
88+
89+
@testset "Pow" begin
90+
p1 = a^b
91+
p2 = a^b
92+
@test p1 === p2
93+
p3 = a^(2^-b)
94+
p4 = a^(2^-b)
95+
@test p3 === p4
96+
p5 = Pow{Float64}(a,b)
97+
@test p1 !== p5
98+
99+
pm1 = setmetadata(p1,Ctx1, "meta_1")
100+
@test pm1 !== p1
101+
end
102+
103+
@testset "Equivalent numbers" begin
104+
f = 0.5
105+
r = 1 // 2
106+
@test hash(f) == hash(r)
107+
u0 = zero(UInt)
108+
@test hash2(f, u0) != hash2(r, u0)
109+
@test f + a !== r + a
110+
end

0 commit comments

Comments
 (0)