Skip to content

Commit e0963ff

Browse files
Merge pull request #807 from JuliaSymbolics/as/fix-bugs
fix: fix stochastic hashing
2 parents 403629b + f9d8198 commit e0963ff

File tree

1 file changed

+42
-6
lines changed

1 file changed

+42
-6
lines changed

src/types.jl

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -967,11 +967,11 @@ function isequal_bsimpl(a::BSImpl.Type{T}, b::BSImpl.Type{T}, full::Bool) where
967967
(BSImpl.AddMul(; coeff = c1, dict = d1, variant = v1, shape = s1, type = t1), BSImpl.AddMul(; coeff = c2, dict = d2, variant = v2, shape = s2, type = t2)) => begin
968968
isequal_somescalar(c1, c2) && (!full || (typeof(c1) === typeof(c2))) && isequal_addmuldict(d1, d2, full) && isequal(v1, v2) && s1 == s2 && t1 === t2
969969
end
970-
(BSImpl.Div(; num = n1, den = d1, type = t1), BSImpl.Div(; num = n2, den = d2, type = t2)) => begin
971-
isequal_bsimpl(n1, n2, full) && isequal_bsimpl(d1, d2, full) && t1 === t2
970+
(BSImpl.Div(; num = n1, den = d1, type = t1, shape = s1), BSImpl.Div(; num = n2, den = d2, type = t2, shape = s2)) => begin
971+
isequal_bsimpl(n1, n2, full) && isequal_bsimpl(d1, d2, full) && s1 == s2 && t1 === t2
972972
end
973973
(BSImpl.ArrayOp(; output_idx = o1, expr = e1, reduce = f1, term = t1, ranges = r1, shape = s1, type = type1), BSImpl.ArrayOp(; output_idx = o2, expr = e2, reduce = f2, term = t2, ranges = r2, shape = s2, type = type2)) => begin
974-
isequal(o1, o2) && isequal(e1, e2) && isequal(f1, f2)::Bool && isequal(t1, t2) && isequal_rangesdict(r1, r2, full) && s1 == s2 && t1 === t2
974+
isequal(o1, o2) && isequal(e1, e2) && isequal(f1, f2)::Bool && isequal(t1, t2) && isequal_rangesdict(r1, r2, full) && s1 == s2 && type1 === type2
975975
end
976976
end
977977
if full && partial && !(Ta <: BSImpl.Const)
@@ -1054,6 +1054,41 @@ function hash_rangesdict(d::RangesT, h::UInt, full::Bool)
10541054
return hash(hv, h)
10551055
end
10561056

1057+
"""
1058+
$METHODLIST
1059+
1060+
Custom hash functions for `vartype(x)`, since hashes of types defined in a module are not
1061+
stable across machines or processes.
1062+
"""
1063+
vartype_hash(::Type{SymReal}, h::UInt) = hash(0x3fffc14710d3391a, h)
1064+
vartype_hash(::Type{SafeReal}, h::UInt) = hash(0x0e8c1e3ac836f40d, h)
1065+
vartype_hash(::Type{TreeReal}, h::UInt) = hash(0x44ec30357ff75155, h)
1066+
1067+
"""
1068+
$TYPEDSIGNATURES
1069+
1070+
Custom hash functions for `AddMul.variant`, since it falls back to the `Base.Enum`
1071+
implementation, which uses `objectid`, which changes across runs.
1072+
"""
1073+
hash_addmulvariant(x::AddMulVariant.T, h::UInt) = hash(x === AddMulVariant.ADD ? 0x6d86258fc9cc0742 : 0x5e0a17a14cd8c815, h)
1074+
1075+
const FNTYPE_SEED = 0x8b414291138f6c45
1076+
1077+
"""
1078+
$TYPEDSIGNATURES
1079+
1080+
Custom hash function for a type that may be an `FnType`, since hashes of types defined in a module are not
1081+
stable across machines or processes.
1082+
"""
1083+
function hash_maybe_fntype(T::TypeT, h::UInt)
1084+
@nospecialize T
1085+
if T <: FnType
1086+
hash(T.parameters[1], hash(T.parameters[2], hash(T.parameters[3], h)::UInt)::UInt)::UInt FNTYPE_SEED
1087+
else
1088+
hash(T, h)::UInt
1089+
end
1090+
end
1091+
10571092
"""
10581093
hash_bsimpl(s::BSImpl.Type{T}, h::UInt, full) where {T}
10591094
@@ -1064,7 +1099,7 @@ function hash_bsimpl(s::BSImpl.Type{T}, h::UInt, full) where {T}
10641099
if !iszero(h)
10651100
return hash(hash_bsimpl(s, zero(h), full), h)::UInt
10661101
end
1067-
h = hash(T, h)
1102+
h = vartype_hash(T, h)
10681103

10691104
partial::UInt = @match s begin
10701105
BSImpl.Const(; val, hash) => begin
@@ -1083,7 +1118,7 @@ function hash_bsimpl(s::BSImpl.Type{T}, h::UInt, full) where {T}
10831118
!full && !iszero(hash) && return hash
10841119
h = Base.hash(name, h)
10851120
h = Base.hash(shape, h)
1086-
h = Base.hash(type, h)
1121+
h = hash_maybe_fntype(type, h)
10871122
h SYM_SALT
10881123
end
10891124
BSImpl.Term(; f, args, shape, hash, hash2, type) => begin
@@ -1094,7 +1129,7 @@ function hash_bsimpl(s::BSImpl.Type{T}, h::UInt, full) where {T}
10941129
BSImpl.AddMul(; coeff, dict, variant, shape, type, hash, hash2) => begin
10951130
full && !iszero(hash2) && return hash2
10961131
!full && !iszero(hash) && return hash
1097-
htmp = hash_somescalar(coeff, hash_addmuldict(dict, Base.hash(variant, Base.hash(shape, Base.hash(type, h))), full))
1132+
htmp = hash_somescalar(coeff, hash_addmuldict(dict, hash_addmulvariant(variant, Base.hash(shape, Base.hash(type, h))), full))
10981133
if full
10991134
htmp = Base.hash(typeof(coeff), htmp)
11001135
end
@@ -1136,6 +1171,7 @@ Base.one( s::BSImpl.Type) = one( symtype(s))
11361171
Return a `Const` symbolic wrapping `1`.
11371172
"""
11381173
Base.one(::Type{BSImpl.Type{T}}) where {T} = one_of_vartype(T)
1174+
Base.oneunit(::Type{BSImpl.Type{T}}) where {T} = one_of_vartype(T)
11391175
"""
11401176
$TYPEDSIGNATURES
11411177

0 commit comments

Comments
 (0)