@@ -60,8 +60,6 @@ struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
60
60
model:: Model
61
61
" The VarInfo that was used"
62
62
varinfo:: AbstractVarInfo
63
- " The evaluation context that was used"
64
- context:: AbstractContext
65
63
" The values at which the model was evaluated"
66
64
params:: Vector{Tparams}
67
65
" The AD backend that was tested"
92
90
grad_atol=1e-6,
93
91
varinfo::AbstractVarInfo=link(VarInfo(model), model),
94
92
params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
95
- context::AbstractContext=DefaultContext(),
96
93
reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
97
94
expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
98
95
verbose=true,
@@ -146,13 +143,7 @@ Everything else is optional, and can be categorised into several groups:
146
143
prep_params)`. You could then evaluate the gradient at a different set of
147
144
parameters using the `params` keyword argument.
148
145
149
- 3. _How to specify the evaluation context._
150
-
151
- A `DynamicPPL.AbstractContext` can be passed as the `context` keyword
152
- argument to control the evaluation context. This defaults to
153
- `DefaultContext()`.
154
-
155
- 4. _How to specify the results to compare against._ (Only if `test=true`.)
146
+ 3. _How to specify the results to compare against._ (Only if `test=true`.)
156
147
157
148
Once logp and its gradient has been calculated with the specified `adtype`,
158
149
it must be tested for correctness.
@@ -167,12 +158,12 @@ Everything else is optional, and can be categorised into several groups:
167
158
The default reference backend is ForwardDiff. If none of these parameters are
168
159
specified, ForwardDiff will be used to calculate the ground truth.
169
160
170
- 5 . _How to specify the tolerances._ (Only if `test=true`.)
161
+ 4 . _How to specify the tolerances._ (Only if `test=true`.)
171
162
172
163
The tolerances for the value and gradient can be set using `value_atol` and
173
164
`grad_atol`. These default to 1e-6.
174
165
175
- 6 . _Whether to output extra logging information._
166
+ 5 . _Whether to output extra logging information._
176
167
177
168
By default, this function prints messages when it runs. To silence it, set
178
169
`verbose=false`.
@@ -195,7 +186,6 @@ function run_ad(
195
186
grad_atol:: AbstractFloat = 1e-6 ,
196
187
varinfo:: AbstractVarInfo = link (VarInfo (model), model),
197
188
params:: Union{Nothing,Vector{<:AbstractFloat}} = nothing ,
198
- context:: AbstractContext = DefaultContext (),
199
189
reference_adtype:: AbstractADType = REFERENCE_ADTYPE,
200
190
expected_value_and_grad:: Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}} = nothing ,
201
191
verbose= true ,
@@ -207,7 +197,7 @@ function run_ad(
207
197
208
198
verbose && @info " Running AD on $(model. f) with $(adtype) \n "
209
199
verbose && println (" params : $(params) " )
210
- ldf = LogDensityFunction (model, varinfo, context ; adtype= adtype)
200
+ ldf = LogDensityFunction (model, varinfo; adtype= adtype)
211
201
212
202
value, grad = logdensity_and_gradient (ldf, params)
213
203
grad = collect (grad)
@@ -216,7 +206,7 @@ function run_ad(
216
206
if test
217
207
# Calculate ground truth to compare against
218
208
value_true, grad_true = if expected_value_and_grad === nothing
219
- ldf_reference = LogDensityFunction (model, varinfo, context ; adtype= reference_adtype)
209
+ ldf_reference = LogDensityFunction (model, varinfo; adtype= reference_adtype)
220
210
logdensity_and_gradient (ldf_reference, params)
221
211
else
222
212
expected_value_and_grad
@@ -245,7 +235,6 @@ function run_ad(
245
235
return ADResult (
246
236
model,
247
237
varinfo,
248
- context,
249
238
params,
250
239
adtype,
251
240
value_atol,
0 commit comments