Skip to content

Commit 6d163fd

Browse files
committed
Tweak interface
1 parent c5574ae commit 6d163fd

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

HISTORY.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
### AD testing utilities
88

99
`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default.
10+
To disable this, pass the `linked=false` keyword argument.
1011

1112
### VarInfo constructors
1213

src/test_utils/ad.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,9 @@ end
6464
benchmark=false,
6565
value_atol=1e-6,
6666
grad_atol=1e-6,
67-
varinfo::AbstractVarInfo=link(VarInfo(model), model),
68-
params::Vector{<:Real}=varinfo[:],
67+
linked::Bool=true,
68+
varinfo::AbstractVarInfo=VarInfo(model),
69+
params::Union{Nothing,Vector{<:Real}}=nothing,
6970
reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
7071
expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
7172
verbose=true,
@@ -96,10 +97,12 @@ Everything else is optional, and can be categorised into several groups:
9697
DynamicPPL contains several different types of VarInfo objects which change
9798
the way model evaluation occurs. If you want to use a specific type of
9899
VarInfo, pass it as the `varinfo` argument. Otherwise, it will default to
99-
using a `TypedVarInfo` generated from the model. It will also perform
100-
_linking_, that is, the parameters in the VarInfo will be transformed to
101-
unconstrained Euclidean space if they aren't already in that space. Note
102-
that the act of linking may change the length of the parameters.
100+
using a `TypedVarInfo` generated from the model.
101+
102+
It will also perform _linking_, that is, the parameters in the VarInfo will
103+
be transformed to unconstrained Euclidean space if they aren't already in
104+
that space. Note that the act of linking may change the length of the
105+
parameters. To disable linking, set `linked=false`.
103106
104107
2. _How to specify the parameters._
105108
@@ -151,14 +154,22 @@ function run_ad(
151154
benchmark=false,
152155
value_atol=1e-6,
153156
grad_atol=1e-6,
154-
varinfo::AbstractVarInfo=link(VarInfo(model), model),
155-
params::Vector{<:Real}=varinfo[:],
157+
linked::Bool=true,
158+
varinfo::AbstractVarInfo=VarInfo(model),
159+
params::Union{Nothing,Vector{<:Real}}=nothing,
156160
reference_adtype::AbstractADType=REFERENCE_ADTYPE,
157161
expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
158162
verbose=true,
159163
)::ADResult
160-
verbose && @info "Running AD on $(model.f) with $(adtype)\n"
164+
if linked
165+
varinfo = link(varinfo, model)
166+
end
167+
if isnothing(params)
168+
params = varinfo[:]
169+
end
161170
params = map(identity, params)
171+
172+
verbose && @info "Running AD on $(model.f) with $(adtype)\n"
162173
verbose && println(" params : $(params)")
163174
ldf = LogDensityFunction(model, varinfo; adtype=adtype)
164175

0 commit comments

Comments
 (0)