|
64 | 64 | benchmark=false,
|
65 | 65 | value_atol=1e-6,
|
66 | 66 | grad_atol=1e-6,
|
67 |
| - varinfo::AbstractVarInfo=link(VarInfo(model), model), |
68 |
| - params::Vector{<:Real}=varinfo[:], |
| 67 | + linked::Bool=true, |
| 68 | + varinfo::AbstractVarInfo=VarInfo(model), |
| 69 | + params::Union{Nothing,Vector{<:Real}}=nothing, |
69 | 70 | reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
|
70 | 71 | expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
|
71 | 72 | verbose=true,
|
@@ -96,10 +97,12 @@ Everything else is optional, and can be categorised into several groups:
|
96 | 97 | DynamicPPL contains several different types of VarInfo objects which change
|
97 | 98 | the way model evaluation occurs. If you want to use a specific type of
|
98 | 99 | VarInfo, pass it as the `varinfo` argument. Otherwise, it will default to
|
99 |
| - using a `TypedVarInfo` generated from the model. It will also perform |
100 |
| - _linking_, that is, the parameters in the VarInfo will be transformed to |
101 |
| - unconstrained Euclidean space if they aren't already in that space. Note |
102 |
| - that the act of linking may change the length of the parameters. |
| 100 | + using a `TypedVarInfo` generated from the model. |
| 101 | +
|
| 102 | + It will also perform _linking_, that is, the parameters in the VarInfo will |
| 103 | + be transformed to unconstrained Euclidean space if they aren't already in |
| 104 | + that space. Note that the act of linking may change the length of the |
| 105 | + parameters. To disable linking, set `linked=false`. |
103 | 106 |
|
104 | 107 | 2. _How to specify the parameters._
|
105 | 108 |
|
@@ -151,14 +154,22 @@ function run_ad(
|
151 | 154 | benchmark=false,
|
152 | 155 | value_atol=1e-6,
|
153 | 156 | grad_atol=1e-6,
|
154 |
| - varinfo::AbstractVarInfo=link(VarInfo(model), model), |
155 |
| - params::Vector{<:Real}=varinfo[:], |
| 157 | + linked::Bool=true, |
| 158 | + varinfo::AbstractVarInfo=VarInfo(model), |
| 159 | + params::Union{Nothing,Vector{<:Real}}=nothing, |
156 | 160 | reference_adtype::AbstractADType=REFERENCE_ADTYPE,
|
157 | 161 | expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
|
158 | 162 | verbose=true,
|
159 | 163 | )::ADResult
|
160 |
| - verbose && @info "Running AD on $(model.f) with $(adtype)\n" |
| 164 | + if linked |
| 165 | + varinfo = link(varinfo, model) |
| 166 | + end |
| 167 | + if isnothing(params) |
| 168 | + params = varinfo[:] |
| 169 | + end |
161 | 170 | params = map(identity, params)
|
| 171 | + |
| 172 | + verbose && @info "Running AD on $(model.f) with $(adtype)\n" |
162 | 173 | verbose && println(" params : $(params)")
|
163 | 174 | ldf = LogDensityFunction(model, varinfo; adtype=adtype)
|
164 | 175 |
|
|
0 commit comments