Skip to content

Commit 0b0989f

Browse files
committed
Add SmallTag type
This is an alternative to `Tag` that provides largely the same functionality, but carries around only the hash of the function / array types instead of the full types themselves. This can make these types much less bulky to print and easier to visually scan for.
1 parent d813023 commit 0b0989f

File tree

2 files changed

+85
-17
lines changed

2 files changed

+85
-17
lines changed

src/config.jl

Lines changed: 80 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,50 @@ end
2020

2121
Tag(::Nothing, ::Type{V}) where {V} = nothing
2222

23-
2423
@inline function (::Type{Tag{F1,V1}}, ::Type{Tag{F2,V2}}) where {F1,V1,F2,V2}
2524
tagcount(Tag{F1,V1}) < tagcount(Tag{F2,V2})
2625
end
2726

27+
"""
28+
HashTag{Hash}
29+
30+
HashTag is similar to a Tag, but carries just a small UInt64 hash,
31+
instead of the full type, which makes stacktraces / types easier to
32+
read while still providing good resilience to perturbation confusion.
33+
"""
34+
struct HashTag{H}
35+
end
36+
37+
@generated function tagcount(::Type{HashTag{H}}) where {H}
38+
:($(Threads.atomic_add!(TAGCOUNT, UInt(1))))
39+
end
40+
41+
function HashTag(f::F, ::Type{V}) where {F,V}
42+
H = if F <: Tuple
43+
# no easy way to check Jacobian tag used with Hessians as multiple functions may be used
44+
# see checktag(::Type{Tag{FT,VT}}, f::F, x::AbstractArray{V}) where {FT<:Tuple,VT,F,V}
45+
nothing
46+
else
47+
hash(F) hash(V)
48+
end
49+
tagcount(HashTag{H}) # trigger generated function
50+
HashTag{H}()
51+
end
52+
53+
HashTag(::Nothing, ::Type{V}) where {V} = nothing
54+
55+
@inline function (::Type{HashTag{H1}}, ::Type{Tag{F2,V2}}) where {H1,F2,V2}
56+
tagcount(HashTag{H1}) < tagcount(Tag{F2,V2})
57+
end
58+
59+
@inline function (::Type{Tag{F1,V1}}, ::Type{HashTag{H2}}) where {F1,V1,H2}
60+
tagcount(Tag{F1,V1}) < tagcount(HashTag{H2})
61+
end
62+
63+
@inline function (::Type{HashTag{H1}}, ::Type{HashTag{H2}}) where {H1,H2}
64+
tagcount(HashTag{H1}) < tagcount(HashTag{H2})
65+
end
66+
2867
struct InvalidTagException{E,O} <: Exception
2968
end
3069

@@ -36,13 +75,22 @@ checktag(::Type{Tag{FT,VT}}, f::F, x::AbstractArray{V}) where {FT,VT,F,V} =
3675

3776
checktag(::Type{Tag{F,V}}, f::F, x::AbstractArray{V}) where {F,V} = true
3877

78+
# HashTag is a smaller tag, that only confirms the hash
79+
function checktag(::Type{HashTag{HT}}, f::F, x::AbstractArray{V}) where {HT,F,V}
80+
H = hash(F) hash(V)
81+
if HT == H || HT === nothing
82+
true
83+
else
84+
throw(InvalidTagException{HashTag{H},HashTag{HT}}())
85+
end
86+
end
87+
3988
# no easy way to check Jacobian tag used with Hessians as multiple functions may be used
4089
checktag(::Type{Tag{FT,VT}}, f::F, x::AbstractArray{V}) where {FT<:Tuple,VT,F,V} = true
4190

4291
# custom tag: you're on your own.
4392
checktag(z, f, x) = true
4493

45-
4694
##################
4795
# AbstractConfig #
4896
##################
@@ -55,6 +103,21 @@ Base.eltype(cfg::AbstractConfig) = eltype(typeof(cfg))
55103

56104
@inline (chunksize(::AbstractConfig{N})::Int) where {N} = N
57105

106+
@inline function maketag(f, X; style::Union{Symbol,Nothing} = nothing)
107+
if style === :hash
108+
return HashTag(f, X)
109+
elseif style === :type
110+
return Tag(f, X)
111+
elseif style === nothing
112+
if HASHTAG_MODE_ENABLED
113+
return HashTag(f, X)
114+
else
115+
return Tag(f, X)
116+
end
117+
end
118+
error("unexpected tag style: $(style)")
119+
end
120+
58121
####################
59122
# DerivativeConfig #
60123
####################
@@ -108,9 +171,9 @@ vector `x`.
108171
The returned `GradientConfig` instance contains all the work buffers required by
109172
`ForwardDiff.gradient` and `ForwardDiff.gradient!`.
110173
111-
If `f` is `nothing` instead of the actual target function, then the returned instance can
112-
be used with any target function. However, this will reduce ForwardDiff's ability to catch
113-
and prevent perturbation confusion (see https://github.com/JuliaDiff/ForwardDiff.jl/issues/83).
174+
If `f` or `tag` is `nothing`, then the returned instance can be used with any target function.
175+
However, this will reduce ForwardDiff's ability to catch and prevent perturbation confusion
176+
(see https://github.com/JuliaDiff/ForwardDiff.jl/issues/83).
114177
115178
This constructor does not store/modify `x`.
116179
"""
@@ -145,9 +208,9 @@ The returned `JacobianConfig` instance contains all the work buffers required by
145208
`ForwardDiff.jacobian` and `ForwardDiff.jacobian!` when the target function takes the form
146209
`f(x)`.
147210
148-
If `f` is `nothing` instead of the actual target function, then the returned instance can
149-
be used with any target function. However, this will reduce ForwardDiff's ability to catch
150-
and prevent perturbation confusion (see https://github.com/JuliaDiff/ForwardDiff.jl/issues/83).
211+
If `f` or `tag` is `nothing`, then the returned instance can be used with any target function.
212+
However, this will reduce ForwardDiff's ability to catch and prevent perturbation confusion
213+
(see https://github.com/JuliaDiff/ForwardDiff.jl/issues/83).
151214
152215
This constructor does not store/modify `x`.
153216
"""
@@ -170,9 +233,9 @@ The returned `JacobianConfig` instance contains all the work buffers required by
170233
`ForwardDiff.jacobian` and `ForwardDiff.jacobian!` when the target function takes the form
171234
`f!(y, x)`.
172235
173-
If `f!` is `nothing` instead of the actual target function, then the returned instance can
174-
be used with any target function. However, this will reduce ForwardDiff's ability to catch
175-
and prevent perturbation confusion (see https://github.com/JuliaDiff/ForwardDiff.jl/issues/83).
236+
If `f!` or `tag` is `nothing`, then the returned instance can be used with any target function.
237+
However, this will reduce ForwardDiff's ability to catch and prevent perturbation confusion
238+
(see https://github.com/JuliaDiff/ForwardDiff.jl/issues/83).
176239
177240
This constructor does not store/modify `y` or `x`.
178241
"""
@@ -212,9 +275,9 @@ configured for the case where the `result` argument is an `AbstractArray`. If
212275
it is a `DiffResult`, the `HessianConfig` should instead be constructed via
213276
`ForwardDiff.HessianConfig(f, result, x, chunk)`.
214277
215-
If `f` is `nothing` instead of the actual target function, then the returned instance can
216-
be used with any target function. However, this will reduce ForwardDiff's ability to catch
217-
and prevent perturbation confusion (see https://github.com/JuliaDiff/ForwardDiff.jl/issues/83).
278+
If `f` or `tag` is `nothing`, then the returned instance can be used with any target function.
279+
However, this will reduce ForwardDiff's ability to catch and prevent perturbation confusion
280+
(see https://github.com/JuliaDiff/ForwardDiff.jl/issues/83).
218281
219282
This constructor does not store/modify `x`.
220283
"""
@@ -236,9 +299,9 @@ type/shape of the input vector `x`.
236299
The returned `HessianConfig` instance contains all the work buffers required by
237300
`ForwardDiff.hessian!` for the case where the `result` argument is an `DiffResult`.
238301
239-
If `f` is `nothing` instead of the actual target function, then the returned instance can
240-
be used with any target function. However, this will reduce ForwardDiff's ability to catch
241-
and prevent perturbation confusion (see https://github.com/JuliaDiff/ForwardDiff.jl/issues/83).
302+
If `f` or `tag` is `nothing`, then the returned instance can be used with any target function.
303+
However, this will reduce ForwardDiff's ability to catch and prevent perturbation confusion
304+
(see https://github.com/JuliaDiff/ForwardDiff.jl/issues/83).
242305
243306
This constructor does not store/modify `x`.
244307
"""

src/prelude.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
const NANSAFE_MODE_ENABLED = @load_preference("nansafe_mode", false)
22
const DEFAULT_CHUNK_THRESHOLD = @load_preference("default_chunk_threshold", 12)
33

4+
# On ≤1.10, the hash of a type cannot be computed at compile-time,
5+
# making `HashTag(...)` type-unstable, so `Tag(...)` is left as
6+
# as the default.
7+
const HASHTAG_MODE_ENABLED = @load_preference("hashtag_mode", VERSION v"1.11")
8+
49
const AMBIGUOUS_TYPES = (AbstractFloat, Irrational, Integer, Rational, Real, RoundingMode)
510

611
const UNARY_PREDICATES = Symbol[:isinf, :isnan, :isfinite, :iseven, :isodd, :isreal, :isinteger]

0 commit comments

Comments
 (0)