Commit 9951638

Improve stacktraces by using custom tag for ForwardDiff (#1841)
* Improve stacktraces by using custom tag for ForwardDiff
* Fix typo
* Additional fixes
* Simplify code and define `LogDensityFunction`
* A bit simpler
* Add tests
1 parent 8adfa22 commit 9951638
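Background for the change: ForwardDiff encodes a tag type in every `Dual` number, and by default that tag is derived from the type of the function being differentiated. For closures this bakes a very long type into every `Dual`, which makes stacktraces hard to read. A minimal sketch of the effect, assuming a hypothetical `DemoTag` as a stand-in for the `TuringTag` introduced in this commit:

    using ForwardDiff

    # A short, named tag keeps the Dual's type parameter readable in stacktraces,
    # instead of embedding the (often enormous) type of the differentiated closure.
    struct DemoTag end  # hypothetical stand-in for Turing.TuringTag

    x = 1.0
    d = ForwardDiff.Dual{ForwardDiff.Tag{DemoTag,Float64}}(x, one(x))
    typeof(d)  # ForwardDiff.Dual{ForwardDiff.Tag{DemoTag, Float64}, Float64, 1}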

7 files changed: +104 −78 lines changed

Project.toml

Lines changed: 3 additions & 1 deletion

@@ -1,6 +1,6 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.21.5"
+version = "0.21.6"

 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -11,6 +11,7 @@ AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
 BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
+DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -42,6 +43,7 @@ AdvancedVI = "0.1"
 BangBang = "0.3"
 Bijectors = "0.8, 0.9, 0.10"
 DataStructures = "0.18"
+DiffResults = "1"
 Distributions = "0.23.3, 0.24, 0.25"
 DistributionsAD = "0.6"
 DocStringExtensions = "0.8"

src/Turing.jl

Lines changed: 19 additions & 0 deletions

@@ -25,6 +25,25 @@ function setprogress!(progress::Bool)
     return progress
 end

+# Log density function
+struct LogDensityFunction{V,M,S,C}
+    varinfo::V
+    model::M
+    sampler::S
+    context::C
+end
+
+function (f::LogDensityFunction)(θ::AbstractVector)
+    return getlogp(last(DynamicPPL.evaluate!!(f.model, VarInfo(f.varinfo, f.sampler, θ), f.sampler, f.context)))
+end
+
+# Standard tag: Improves stacktraces
+# Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
+struct TuringTag end
+
+# Allow Turing tag in gradient etc. calls of the log density function
+ForwardDiff.checktag(::Type{ForwardDiff.Tag{TuringTag, V}}, ::LogDensityFunction, ::AbstractArray{V}) where {V} = true
+
 # Random probability measures.
 include("stdlib/distributions.jl")
 include("stdlib/RandomMeasures.jl")

src/essential/Essential.jl

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@ using StatsFuns: logsumexp, softmax
 using Requires

 import AdvancedPS
+import DiffResults
 import ZygoteRules

 include("container.jl")

src/essential/ad.jl

Lines changed: 41 additions & 45 deletions

@@ -34,11 +34,18 @@ function setchunksize(chunk_size::Int)
 end

 abstract type ADBackend end
-struct ForwardDiffAD{chunk} <: ADBackend end
+struct ForwardDiffAD{chunk,standardtag} <: ADBackend end
+
+# Use standard tag if not specified otherwise
+ForwardDiffAD{N}() where {N} = ForwardDiffAD{N,true}()
+
 getchunksize(::Type{<:ForwardDiffAD{chunk}}) where chunk = chunk
 getchunksize(::Type{<:Sampler{Talg}}) where Talg = getchunksize(Talg)
 getchunksize(::Type{SampleFromPrior}) = CHUNKSIZE[]

+standardtag(::ForwardDiffAD{<:Any,true}) = true
+standardtag(::ForwardDiffAD) = false
+
 struct TrackerAD <: ADBackend end
 struct ZygoteAD <: ADBackend end

@@ -95,59 +102,54 @@ Compute the value of the log joint of `θ` and its gradient for the model
 specified by `(vi, sampler, model)` using `backend` for AD, e.g. `ForwardDiffAD{N}()` uses `ForwardDiff.jl` with chunk size `N`, `TrackerAD()` uses `Tracker.jl` and `ZygoteAD()` uses `Zygote.jl`.
 """
 function gradient_logp(
-    ::ForwardDiffAD,
+    ad::ForwardDiffAD,
     θ::AbstractVector{<:Real},
     vi::VarInfo,
     model::Model,
     sampler::AbstractSampler=SampleFromPrior(),
-    ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
+    context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
 )
-    # Define function to compute log joint.
-    logp_old = getlogp(vi)
-    function f(θ)
-        new_vi = VarInfo(vi, sampler, θ)
-        new_vi = last(DynamicPPL.evaluate!!(model, new_vi, sampler, ctx))
-        logp = getlogp(new_vi)
-        # Don't need to capture the resulting `vi` since this is only
-        # needed if `vi` is mutable.
-        setlogp!!(vi, ForwardDiff.value(logp))
-        return logp
-    end
+    # Define log density function.
+    f = Turing.LogDensityFunction(vi, model, sampler, context)

-    # Set chunk size and do ForwardMode.
-    chunk_size = getchunksize(typeof(sampler))
+    # Define configuration for ForwardDiff.
+    tag = if standardtag(ad)
+        ForwardDiff.Tag(Turing.TuringTag(), eltype(θ))
+    else
+        ForwardDiff.Tag(f, eltype(θ))
+    end
+    chunk_size = getchunksize(typeof(ad))
     config = if chunk_size == 0
-        ForwardDiff.GradientConfig(f, θ)
+        ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(θ), tag)
     else
-        ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size))
+        ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size), tag)
     end
-    ∂l∂θ = ForwardDiff.gradient!(similar(θ), f, θ, config)
-    l = getlogp(vi)
-    setlogp!!(vi, logp_old)

-    return l, ∂l∂θ
+    # Obtain both value and gradient of the log density function.
+    out = DiffResults.GradientResult(θ)
+    ForwardDiff.gradient!(out, f, θ, config)
+    logp = DiffResults.value(out)
+    ∂logp∂θ = DiffResults.gradient(out)
+
+    return logp, ∂logp∂θ
 end
 function gradient_logp(
     ::TrackerAD,
     θ::AbstractVector{<:Real},
     vi::VarInfo,
     model::Model,
     sampler::AbstractSampler = SampleFromPrior(),
-    ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
+    context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
 )
-    T = typeof(getlogp(vi))
-
-    # Specify objective function.
-    function f(θ)
-        new_vi = VarInfo(vi, sampler, θ)
-        new_vi = last(DynamicPPL.evaluate!!(model, new_vi, sampler, ctx))
-        return getlogp(new_vi)
-    end
+    # Define log density function.
+    f = Turing.LogDensityFunction(vi, model, sampler, context)

-    # Compute forward and reverse passes.
+    # Compute forward pass and pullback.
     l_tracked, ȳ = Tracker.forward(f, θ)
-    # Remove tracking info from variables in model (because mutable state).
-    l::T, ∂l∂θ::typeof(θ) = Tracker.data(l_tracked), Tracker.data(ȳ(1)[1])
+
+    # Remove tracking info.
+    l::typeof(getlogp(vi)) = Tracker.data(l_tracked)
+    ∂l∂θ::typeof(θ) = Tracker.data(only(ȳ(1)))

     return l, ∂l∂θ
 end
@@ -160,18 +162,12 @@ function gradient_logp(
     sampler::AbstractSampler = SampleFromPrior(),
     context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
 )
-    T = typeof(getlogp(vi))
-
-    # Specify objective function.
-    function f(θ)
-        new_vi = VarInfo(vi, sampler, θ)
-        new_vi = last(DynamicPPL.evaluate!!(model, new_vi, sampler, context))
-        return getlogp(new_vi)
-    end
+    # Define log density function.
+    f = Turing.LogDensityFunction(vi, model, sampler, context)

-    # Compute forward and reverse passes.
-    l::T, ȳ = ZygoteRules.pullback(f, θ)
-    ∂l∂θ::typeof(θ) = ȳ(1)[1]
+    # Compute forward pass and pullback.
+    l::typeof(getlogp(vi)), ȳ = ZygoteRules.pullback(f, θ)
+    ∂l∂θ::typeof(θ) = only(ȳ(1))

     return l, ∂l∂θ
 end
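For readers unfamiliar with the two ingredients used in the ForwardDiff branch above, here is a self-contained sketch under simplified assumptions: a plain quadratic instead of a Turing model, and a hypothetical `MyTag` standing in for `TuringTag`. It shows a custom ForwardDiff tag plus `DiffResults` obtaining value and gradient in a single sweep:

    using ForwardDiff, DiffResults

    struct MyTag end  # hypothetical stand-in for Turing.TuringTag

    # Without this overload, ForwardDiff would reject the custom tag with an
    # InvalidTagException when differentiating `f`.
    ForwardDiff.checktag(::Type{ForwardDiff.Tag{MyTag,V}}, f, x::AbstractArray{V}) where {V} = true

    f(x) = sum(abs2, x)
    x = randn(3)

    tag = ForwardDiff.Tag(MyTag(), eltype(x))
    config = ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk(x), tag)

    # One sweep yields both the value and the gradient.
    out = DiffResults.GradientResult(x)
    ForwardDiff.gradient!(out, f, x, config)
    DiffResults.value(out)     # == f(x)
    DiffResults.gradient(out)  # == 2 .* x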

src/essential/compat/reversediff.jl

Lines changed: 14 additions & 25 deletions

@@ -1,5 +1,4 @@
 using .ReverseDiff: compile, GradientTape
-using .ReverseDiff.DiffResults: GradientResult

 struct ReverseDiffAD{cache} <: ADBackend end
 const RDCache = Ref(false)
@@ -22,26 +21,20 @@ function gradient_logp(
     sampler::AbstractSampler = SampleFromPrior(),
     context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
 )
-    T = typeof(getlogp(vi))
+    # Define log density function.
+    f = Turing.LogDensityFunction(vi, model, sampler, context)

-    # Specify objective function.
-    function f(θ)
-        new_vi = VarInfo(vi, sampler, θ)
-        model(new_vi, sampler, context)
-        return getlogp(new_vi)
-    end
+    # Obtain both value and gradient of the log density function.
     tp, result = taperesult(f, θ)
     ReverseDiff.gradient!(result, tp, θ)
-    l = DiffResults.value(result)
-    ∂l∂θ::typeof(θ) = DiffResults.gradient(result)
+    logp = DiffResults.value(result)
+    ∂logp∂θ = DiffResults.gradient(result)

-    return l, ∂l∂θ
+    return logp, ∂logp∂θ
 end

 tape(f, x) = GradientTape(f, x)
-function taperesult(f, x)
-    return tape(f, x), GradientResult(x)
-end
+taperesult(f, x) = (tape(f, x), DiffResults.GradientResult(x))

 @require Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" @eval begin
     setrdcache(::Val{true}) = RDCache[] = true
@@ -58,20 +51,16 @@ end
     sampler::AbstractSampler = SampleFromPrior(),
     context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
 )
-    T = typeof(getlogp(vi))
+    # Define log density function.
+    f = Turing.LogDensityFunction(vi, model, sampler, context)

-    # Specify objective function.
-    function f(θ)
-        new_vi = VarInfo(vi, sampler, θ)
-        model(new_vi, sampler, context)
-        return getlogp(new_vi)
-    end
+    # Obtain both value and gradient of the log density function.
     ctp, result = memoized_taperesult(f, θ)
     ReverseDiff.gradient!(result, ctp, θ)
-    l = DiffResults.value(result)
-    ∂l∂θ = DiffResults.gradient(result)
+    logp = DiffResults.value(result)
+    ∂logp∂θ = DiffResults.gradient(result)

-    return l, ∂l∂θ
+    return logp, ∂logp∂θ
 end

 # This makes sure we generate a single tape per Turing model and sampler
@@ -85,7 +74,7 @@ end
 end
 memoized_taperesult(f, x) = memoized_taperesult(RDTapeKey(f, x))
 Memoization.@memoize Dict function memoized_taperesult(k::RDTapeKey)
-    return compiledtape(k.f, k.x), GradientResult(k.x)
+    return compiledtape(k.f, k.x), DiffResults.GradientResult(k.x)
 end
 compiledtape(f, x) = compile(GradientTape(f, x))
 end
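The tape/result pairing behind `taperesult` and `memoized_taperesult` follows the standard ReverseDiff pattern; a minimal sketch on a plain function (all names here are illustrative, not Turing API):

    using ReverseDiff, DiffResults

    g(x) = sum(abs2, x)
    x = randn(4)

    # Record a tape once and reuse it; pairing it with a GradientResult lets a
    # single gradient! call return both value and gradient.
    tp = ReverseDiff.GradientTape(g, x)
    result = DiffResults.GradientResult(x)
    ReverseDiff.gradient!(result, tp, x)
    DiffResults.value(result), DiffResults.gradient(result)

    # Compiled tape, analogous to the memoized/cached path above.
    ctp = ReverseDiff.compile(tp)
    ReverseDiff.gradient!(result, ctp, x)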

test/essential/ad.jl

Lines changed: 12 additions & 0 deletions

@@ -177,4 +177,16 @@
         @test Turing.CHUNKSIZE[] == 0
         @test Turing.AdvancedVI.CHUNKSIZE[] == 0
     end
+
+    @testset "tag" begin
+        @test Turing.ADBackend(Val(:forwarddiff))() === Turing.ForwardDiffAD{Turing.CHUNKSIZE[],true}()
+        for chunksize in (0, 1, 10)
+            ad = Turing.ForwardDiffAD{chunksize}()
+            @test ad === Turing.ForwardDiffAD{chunksize,true}()
+            @test Turing.Essential.standardtag(ad)
+            for standardtag in (false, 0, 1)
+                @test !Turing.Essential.standardtag(Turing.ForwardDiffAD{chunksize,standardtag}())
+            end
+        end
+    end
 end

test/test_utils/ad_utils.jl

Lines changed: 14 additions & 7 deletions

@@ -83,13 +83,20 @@ function test_model_ad(model, f, syms::Vector{Symbol})
         end
     end

-    spl = SampleFromPrior()
-    _, ∇E = gradient_logp(ForwardDiffAD{1}(), vi[spl], vi, model)
-    grad_Turing = sort(∇E)
+    # Compute primal.
+    x = vec(vnvals)
+    logp = f(x)

-    # Call ForwardDiff's AD
-    grad_FWAD = sort(ForwardDiff.gradient(f, vec(vnvals)))
+    # Call ForwardDiff's AD directly.
+    grad_FWAD = sort(ForwardDiff.gradient(f, x))

-    # Compare result
-    @test grad_Turing ≈ grad_FWAD atol=1e-9
+    # Compare with `gradient_logp`.
+    z = vi[SampleFromPrior()]
+    for chunksize in (0, 1, 10), standardtag in (true, false, 0, 3)
+        l, ∇E = gradient_logp(ForwardDiffAD{chunksize, standardtag}(), z, vi, model)
+
+        # Compare result
+        @test l ≈ logp
+        @test sort(∇E) ≈ grad_FWAD atol=1e-9
+    end
 end
