Skip to content

Commit fc5daa0

Browse files
Merge pull request #713 from JuliaSymbolics/b/WeakKeyDict
Replace `WeakValueDict` with `WeakKeyDict` for hash consing
2 parents a5bef9a + 268a5c3 commit fc5daa0

File tree

4 files changed

+28
-21
lines changed

4 files changed

+28
-21
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ TaskLocalValues = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34"
2727
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
2828
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
2929
Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415"
30-
WeakValueDicts = "897b6980-f191-5a31-bcb0-bf3c4585e0c1"
3130

3231
[weakdeps]
3332
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
@@ -61,7 +60,6 @@ TaskLocalValues = "0.1.2"
6160
TermInterface = "2.0"
6261
TimerOutputs = "0.5"
6362
Unityper = "0.1.2"
64-
WeakValueDicts = "0.1.0"
6563
julia = "1.10"
6664

6765
[extras]

src/SymbolicUtils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import TermInterface: iscall, isexpr, head, children,
2020
operation, arguments, metadata, maketerm, sorted_arguments
2121
# For ReverseDiffExt
2222
import ArrayInterface
23-
using WeakValueDicts: WeakValueDict
2423
import ExproniconLite as EL
2524
import TaskLocalValues: TaskLocalValue
2625

src/types.jl

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,39 +23,39 @@ const EMPTY_DICT_T = typeof(EMPTY_DICT)
2323
const ENABLE_HASHCONSING = Ref(true)
2424

2525
@compactify show_methods=false begin
26-
@abstract mutable struct BasicSymbolic{T} <: Symbolic{T}
26+
@abstract struct BasicSymbolic{T} <: Symbolic{T}
2727
metadata::Metadata = NO_METADATA
2828
end
29-
mutable struct Sym{T} <: BasicSymbolic{T}
29+
struct Sym{T} <: BasicSymbolic{T}
3030
name::Symbol = :OOF
3131
end
32-
mutable struct Term{T} <: BasicSymbolic{T}
32+
struct Term{T} <: BasicSymbolic{T}
3333
f::Any = identity # base/num if Pow; issorted if Add/Dict
3434
arguments::Vector{Any} = EMPTY_ARGS
3535
hash::RefValue{UInt} = EMPTY_HASH
3636
hash2::RefValue{UInt} = EMPTY_HASH
3737
end
38-
mutable struct Mul{T} <: BasicSymbolic{T}
38+
struct Mul{T} <: BasicSymbolic{T}
3939
coeff::Any = 0 # exp/den if Pow
4040
dict::EMPTY_DICT_T = EMPTY_DICT
4141
hash::RefValue{UInt} = EMPTY_HASH
4242
hash2::RefValue{UInt} = EMPTY_HASH
4343
arguments::Vector{Any} = EMPTY_ARGS
4444
end
45-
mutable struct Add{T} <: BasicSymbolic{T}
45+
struct Add{T} <: BasicSymbolic{T}
4646
coeff::Any = 0 # exp/den if Pow
4747
dict::EMPTY_DICT_T = EMPTY_DICT
4848
hash::RefValue{UInt} = EMPTY_HASH
4949
hash2::RefValue{UInt} = EMPTY_HASH
5050
arguments::Vector{Any} = EMPTY_ARGS
5151
end
52-
mutable struct Div{T} <: BasicSymbolic{T}
52+
struct Div{T} <: BasicSymbolic{T}
5353
num::Any = 1
5454
den::Any = 1
5555
simplified::Bool = false
5656
arguments::Vector{Any} = EMPTY_ARGS
5757
end
58-
mutable struct Pow{T} <: BasicSymbolic{T}
58+
struct Pow{T} <: BasicSymbolic{T}
5959
base::Any = 1
6060
exp::Any = 1
6161
arguments::Vector{Any} = EMPTY_ARGS
@@ -86,7 +86,15 @@ function exprtype(x::BasicSymbolic)
8686
end
8787
end
8888

89-
const wvd = TaskLocalValue{WeakValueDict{UInt, BasicSymbolic}}(WeakValueDict{UInt, BasicSymbolic})
89+
mutable struct HashConsingWrapper
90+
bs::BasicSymbolic
91+
end
92+
93+
Base.hash(x::HashConsingWrapper, h::UInt) = hash2(x.bs, h)
94+
95+
Base.isequal(x::HashConsingWrapper, y::HashConsingWrapper) = isequal_with_metadata(x.bs, y.bs)
96+
97+
const wkd = TaskLocalValue{WeakKeyDict{HashConsingWrapper, Nothing}}(WeakKeyDict{HashConsingWrapper, Nothing})
9098

9199
# Same but different error messages
92100
@noinline error_on_type() = error("Internal error: unreachable reached!")
@@ -522,15 +530,15 @@ Implements hash consing (flyweight design pattern) for `BasicSymbolic` objects.
522530
523531
This function checks if an equivalent `BasicSymbolic` object already exists. It uses a
524532
custom hash function (`hash2`) incorporating metadata and symtypes to search for existing
525-
objects in a `WeakValueDict` (`wvd`). Due to the possibility of hash collisions (where
533+
objects in a `WeakKeyDict` (`wkd`). Due to the possibility of hash collisions (where
526534
different objects produce the same hash), a custom equality check (`isequal_with_metadata`)
527535
which includes metadata comparison, is used to confirm the equivalence of objects with
528536
matching hashes. If an equivalent object is found, the existing object is returned;
529537
otherwise, the input `s` is returned. This reduces memory usage, improves compilation time
530538
for runtime code generation, and supports built-in common subexpression elimination,
531539
particularly when working with symbolic objects with metadata.
532540
533-
Using a `WeakValueDict` ensures that only weak references to `BasicSymbolic` objects are
541+
Using a `WeakKeyDict` ensures that only weak references to `BasicSymbolic` objects are
534542
stored, allowing objects that are no longer strongly referenced to be garbage collected.
535543
Custom functions `hash2` and `isequal_with_metadata` are used instead of `Base.hash` and
536544
`Base.isequal` to accommodate metadata without disrupting existing tests reliant on the
@@ -540,12 +548,14 @@ function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic
540548
if !ENABLE_HASHCONSING[]
541549
return s
542550
end
543-
h = hash2(s)
544-
t = get!(wvd[], h, s)
545-
if t === s || isequal_with_metadata(t, s)
546-
t
551+
cache = wkd[]
552+
hcw = HashConsingWrapper(s)
553+
k = getkey(cache, hcw, nothing)
554+
if isnothing(k)
555+
cache[hcw] = nothing
556+
return s
547557
else
548-
s
558+
return k.bs
549559
end
550560
end
551561

test/hash_consing.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,9 @@ end
133133

134134
@testset "Hashconsing can be toggled" begin
135135
SymbolicUtils.ENABLE_HASHCONSING[] = false
136-
name = gensym(:x)
137-
x1 = only(@eval @syms $name)
138-
x2 = only(@eval @syms $name)
136+
@syms a b
137+
x1 = a + b
138+
x2 = a + b
139139
@test x1 !== x2
140140
SymbolicUtils.ENABLE_HASHCONSING[] = true
141141
end

0 commit comments

Comments
 (0)