Skip to content

Commit 13e1f15

Browse files
penelopeysmsunxd3
andcommitted
Improve docstring
Co-authored-by: Xianda Sun <[email protected]>
1 parent 50c8598 commit 13e1f15

File tree

1 file changed

+36
-21
lines changed

1 file changed

+36
-21
lines changed

src/test_utils/ad.jl

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -78,42 +78,57 @@ Test the correctness and/or benchmark the AD backend `adtype` for the model
7878
`model`.
7979
8080
Whether to test and benchmark is controlled by the `test` and `benchmark`
81-
keyword arguments. By default, `test` is `true` and `benchmark` is `false.
81+
keyword arguments. By default, `test` is `true` and `benchmark` is `false`.
8282
8383
Returns an [`ADResult`](@ref) object, which contains the results of the
8484
test and/or benchmark.
8585
86-
This function is not as complicated as its signature makes it look. There are
87-
two things that must be provided:
86+
Note that to run AD successfully you will need to import the AD backend itself.
87+
For example, to test with `AutoReverseDiff()` you will need to run `import
88+
ReverseDiff`.
89+
90+
There are two positional arguments, which absolutely must be provided:
8891
8992
1. `model` - The model being tested.
9093
2. `adtype` - The AD backend being tested.
9194
9295
Everything else is optional, and can be categorised into several groups:
9396
9497
1. _How to specify the VarInfo._ DynamicPPL contains several different types of
95-
VarInfo objects which change the way model evaluation occurs. If you want to
96-
use a specific type of VarInfo, pass it as the `varinfo` argument. Otherwise,
97-
it will default to using a `TypedVarInfo` generated from the model.
98+
VarInfo objects which change the way model evaluation occurs. If you want to
99+
use a specific type of VarInfo, pass it as the `varinfo` argument.
100+
Otherwise, it will default to using a `TypedVarInfo` generated from the
101+
model.
98102
99103
2. _How to specify the parameters._ For maximum control over this, generate a
100-
vector of parameters yourself and pass this as the `params` argument. If you
101-
don't specify this, it will be taken from the contents of the VarInfo. Note
102-
that if the VarInfo is not specified (and thus automatically generated) the
103-
parameters in it will have been sampled from the prior of the model. If you
104-
want to seed the parameter generation, the easiest way is to pass a `rng`
105-
argument to the VarInfo constructor (i.e. do `VarInfo(rng, model)`).
104+
vector of parameters yourself and pass this as the `params` argument. If you
105+
don't specify this, it will be taken from the contents of the VarInfo.
106+
107+
Note that if the VarInfo is not specified (and thus automatically generated)
108+
the parameters in it will have been sampled from the prior of the model. If
109+
you want to seed the parameter generation, the easiest way is to pass a
110+
`rng` argument to the VarInfo constructor (i.e. do `VarInfo(rng, model)`).
111+
112+
Finally, note that these only reflect the parameters used for _evaluating_
113+
the gradient. If you also want to control the parameters used for
114+
_preparing_ the gradient, then you need to manually set these parameters in
115+
the VarInfo object, for example using `vi = DynamicPPL.unflatten(vi,
116+
prep_params)`. You could then evaluate the gradient at a different set of
117+
parameters using the `params` keyword argument.
106118
107119
3. _How to specify the results to compare against._ (Only if `test=true`.) Once
108-
logp and its gradient has been calculated with the specified `adtype`, it must
109-
be tested for correctness. This can be done either by specifying
110-
`reference_adtype`, in which case logp and its gradient will also be calculated
111-
with this reference in order to obtain the ground truth; or by using
112-
`expected_value_and_grad`, which is a tuple of (logp, gradient) that the
113-
calculated values must match. The latter is useful if you are testing multiple
114-
AD backends and want to avoid recalculating the ground truth multiple times.
115-
The default reference backend is ForwardDiff. If none of these parameters are
116-
specified, that will be used to calculate the ground truth.
120+
logp and its gradient has been calculated with the specified `adtype`, it
121+
must be tested for correctness.
122+
123+
This can be done either by specifying `reference_adtype`, in which case logp
124+
and its gradient will also be calculated with this reference in order to
125+
obtain the ground truth; or by using `expected_value_and_grad`, which is a
126+
tuple of `(logp, gradient)` that the calculated values must match. The
127+
latter is useful if you are testing multiple AD backends and want to avoid
128+
recalculating the ground truth multiple times.
129+
130+
The default reference backend is ForwardDiff. If none of these parameters are
131+
specified, ForwardDiff will be used to calculate the ground truth.
117132
118133
4. _How to specify the tolerances._ (Only if `test=true`.) The tolerances for
119134
the value and gradient can be set using `value_atol` and `grad_atol`. These

0 commit comments

Comments
 (0)