@@ -34,11 +34,18 @@ function setchunksize(chunk_size::Int)
34
34
end
35
35
36
36
abstract type ADBackend end
37
- struct ForwardDiffAD{chunk} <: ADBackend end
37
+ struct ForwardDiffAD{chunk,standardtag} <: ADBackend end
38
+
39
+ # Use standard tag if not specified otherwise
40
+ ForwardDiffAD{N}() where {N} = ForwardDiffAD{N,true}()
41
+
38
42
getchunksize(::Type{<:ForwardDiffAD{chunk}}) where chunk = chunk
39
43
getchunksize(::Type{<:Sampler{Talg}}) where Talg = getchunksize(Talg)
40
44
getchunksize(::Type{SampleFromPrior}) = CHUNKSIZE[]
41
45
46
+ standardtag(::ForwardDiffAD{<:Any,true}) = true
47
+ standardtag(::ForwardDiffAD) = false
48
+
42
49
struct TrackerAD <: ADBackend end
43
50
struct ZygoteAD <: ADBackend end
44
51
@@ -95,59 +102,54 @@ Compute the value of the log joint of `θ` and its gradient for the model
95
102
specified by `(vi, sampler, model)` using `backend` for AD, e.g. `ForwardDiffAD{N}()` uses `ForwardDiff.jl` with chunk size `N`, `TrackerAD()` uses `Tracker.jl` and `ZygoteAD()` uses `Zygote.jl`.
96
103
"""
97
104
function gradient_logp(
98
- ::ForwardDiffAD,
105
+ ad::ForwardDiffAD,
99
106
θ::AbstractVector{<:Real},
100
107
vi::VarInfo,
101
108
model::Model,
102
109
sampler::AbstractSampler = SampleFromPrior(),
103
- ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
110
+ context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
104
111
)
105
- # Define function to compute log joint.
106
- logp_old = getlogp(vi)
107
- function f(θ)
108
- new_vi = VarInfo(vi, sampler, θ)
109
- new_vi = last(DynamicPPL.evaluate!!(model, new_vi, sampler, ctx))
110
- logp = getlogp(new_vi)
111
- # Don't need to capture the resulting `vi` since this is only
112
- # needed if `vi` is mutable.
113
- setlogp!!(vi, ForwardDiff.value(logp))
114
- return logp
115
- end
112
+ # Define log density function.
113
+ f = Turing.LogDensityFunction(vi, model, sampler, context)
116
114
117
- # Set chunk size and do ForwardMode.
118
- chunk_size = getchunksize(typeof(sampler))
115
+ # Define configuration for ForwardDiff.
116
+ tag = if standardtag(ad)
117
+ ForwardDiff.Tag(Turing.TuringTag(), eltype(θ))
118
+ else
119
+ ForwardDiff.Tag(f, eltype(θ))
120
+ end
121
+ chunk_size = getchunksize(typeof(ad))
119
122
config = if chunk_size == 0
120
- ForwardDiff.GradientConfig(f, θ)
123
+ ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(θ), tag)
121
124
else
122
- ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size))
125
+ ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size), tag)
123
126
end
124
- ∂l∂θ = ForwardDiff.gradient!(similar(θ), f, θ, config)
125
- l = getlogp(vi)
126
- setlogp!!(vi, logp_old)
127
127
128
- return l, ∂l∂θ
128
+ # Obtain both value and gradient of the log density function.
129
+ out = DiffResults.GradientResult(θ)
130
+ ForwardDiff.gradient!(out, f, θ, config)
131
+ logp = DiffResults.value(out)
132
+ ∂logp∂θ = DiffResults.gradient(out)
133
+
134
+ return logp, ∂logp∂θ
129
135
end
130
136
function gradient_logp(
131
137
::TrackerAD,
132
138
θ::AbstractVector{<:Real},
133
139
vi::VarInfo,
134
140
model::Model,
135
141
sampler::AbstractSampler = SampleFromPrior(),
136
- ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
142
+ context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
137
143
)
138
- T = typeof(getlogp(vi))
139
-
140
- # Specify objective function.
141
- function f(θ)
142
- new_vi = VarInfo(vi, sampler, θ)
143
- new_vi = last(DynamicPPL.evaluate!!(model, new_vi, sampler, ctx))
144
- return getlogp(new_vi)
145
- end
144
+ # Define log density function.
145
+ f = Turing.LogDensityFunction(vi, model, sampler, context)
146
146
147
- # Compute forward and reverse passes.
147
+ # Compute forward pass and pullback.
148
148
l_tracked, ȳ = Tracker.forward(f, θ)
149
- # Remove tracking info from variables in model (because mutable state).
150
- l::T, ∂l∂θ::typeof(θ) = Tracker.data(l_tracked), Tracker.data(ȳ(1)[1])
149
+
150
+ # Remove tracking info.
151
+ l::typeof(getlogp(vi)) = Tracker.data(l_tracked)
152
+ ∂l∂θ::typeof(θ) = Tracker.data(only(ȳ(1)))
151
153
152
154
return l, ∂l∂θ
153
155
end
@@ -160,18 +162,12 @@ function gradient_logp(
160
162
sampler::AbstractSampler = SampleFromPrior(),
161
163
context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
162
164
)
163
- T = typeof(getlogp(vi))
164
-
165
- # Specify objective function.
166
- function f(θ)
167
- new_vi = VarInfo(vi, sampler, θ)
168
- new_vi = last(DynamicPPL.evaluate!!(model, new_vi, sampler, context))
169
- return getlogp(new_vi)
170
- end
165
+ # Define log density function.
166
+ f = Turing.LogDensityFunction(vi, model, sampler, context)
171
167
172
- # Compute forward and reverse passes.
173
- l::T, ȳ = ZygoteRules.pullback(f, θ)
174
- ∂l∂θ::typeof(θ) = ȳ(1)[1]
168
+ # Compute forward pass and pullback.
169
+ l::typeof(getlogp(vi)), ȳ = ZygoteRules.pullback(f, θ)
170
+ ∂l∂θ::typeof(θ) = only(ȳ(1))
175
171
176
172
return l, ∂l∂θ
177
173
end
0 commit comments