@@ -78,42 +78,57 @@ Test the correctness and/or benchmark the AD backend `adtype` for the model
`model`.

Whether to test and benchmark is controlled by the `test` and `benchmark`
- keyword arguments. By default, `test` is `true` and `benchmark` is `false.
+ keyword arguments. By default, `test` is `true` and `benchmark` is `false`.

Returns an [`ADResult`](@ref) object, which contains the results of the
test and/or benchmark.

- This function is not as complicated as its signature makes it look. There are
- two things that must be provided:
+ Note that to run AD successfully you will need to import the AD backend itself.
+ For example, to test with `AutoReverseDiff()` you will need to run `import
+ ReverseDiff`.
+
+ There are two positional arguments, which absolutely must be provided:

1. `model` - The model being tested.
2. `adtype` - The AD backend being tested.

Everything else is optional, and can be categorised into several groups:

1. _How to specify the VarInfo._ DynamicPPL contains several different types of
- VarInfo objects which change the way model evaluation occurs. If you want to
- use a specific type of VarInfo, pass it as the `varinfo` argument. Otherwise,
- it will default to using a `TypedVarInfo` generated from the model.
+ VarInfo objects which change the way model evaluation occurs. If you want to
+ use a specific type of VarInfo, pass it as the `varinfo` argument.
+ Otherwise, it will default to using a `TypedVarInfo` generated from the
+ model.

2. _How to specify the parameters._ For maximum control over this, generate a
- vector of parameters yourself and pass this as the `params` argument. If you
- don't specify this, it will be taken from the contents of the VarInfo. Note
- that if the VarInfo is not specified (and thus automatically generated) the
- parameters in it will have been sampled from the prior of the model. If you
- want to seed the parameter generation, the easiest way is to pass a `rng`
- argument to the VarInfo constructor (i.e. do `VarInfo(rng, model)`).
+ vector of parameters yourself and pass this as the `params` argument. If you
+ don't specify this, it will be taken from the contents of the VarInfo.
+
+ Note that if the VarInfo is not specified (and thus automatically generated)
+ the parameters in it will have been sampled from the prior of the model. If
+ you want to seed the parameter generation, the easiest way is to pass a
+ `rng` argument to the VarInfo constructor (i.e. do `VarInfo(rng, model)`).
+
+ Finally, note that these only reflect the parameters used for _evaluating_
+ the gradient. If you also want to control the parameters used for
+ _preparing_ the gradient, then you need to manually set these parameters in
+ the VarInfo object, for example using `vi = DynamicPPL.unflatten(vi,
+ prep_params)`. You could then evaluate the gradient at a different set of
+ parameters using the `params` keyword argument.

3. _How to specify the results to compare against._ (Only if `test=true`.) Once
- logp and its gradient has been calculated with the specified `adtype`, it must
- be tested for correctness. This can be done either by specifying
- `reference_adtype`, in which case logp and its gradient will also be calculated
- with this reference in order to obtain the ground truth; or by using
- `expected_value_and_grad`, which is a tuple of (logp, gradient) that the
- calculated values must match. The latter is useful if you are testing multiple
- AD backends and want to avoid recalculating the ground truth multiple times.
- The default reference backend is ForwardDiff. If none of these parameters are
- specified, that will be used to calculate the ground truth.
+ logp and its gradient have been calculated with the specified `adtype`, they
+ must be tested for correctness.
+
+ This can be done either by specifying `reference_adtype`, in which case logp
+ and its gradient will also be calculated with this reference in order to
+ obtain the ground truth; or by using `expected_value_and_grad`, which is a
+ tuple of `(logp, gradient)` that the calculated values must match. The
+ latter is useful if you are testing multiple AD backends and want to avoid
+ recalculating the ground truth multiple times.
+
+ The default reference backend is ForwardDiff. If none of these parameters are
+ specified, ForwardDiff will be used to calculate the ground truth.

4. _How to specify the tolerances._ (Only if `test=true`.) The tolerances for
the value and gradient can be set using `value_atol` and `grad_atol`. These
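The following is a rough usage sketch of the function documented in this diff. It is not part of the diff itself. It assumes that `run_ad` lives in `DynamicPPL.TestUtils.AD` and accepts the keyword arguments described in the docstring, and that the returned `ADResult` exposes `value_actual` and `grad_actual` fields. The `demo` model, the seed, and the parameter vectors are hypothetical, and the exact signature may differ between DynamicPPL versions.

```julia
using DynamicPPL, Distributions
using ADTypes: AutoForwardDiff, AutoReverseDiff
using DynamicPPL.TestUtils.AD: run_ad   # assumed module path
using Random: Xoshiro

# The backends themselves must be loaded, not just their ADTypes.
import ForwardDiff
import ReverseDiff

# Hypothetical two-parameter model, used only for illustration.
@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1.0)
end
model = demo()

# Seed the parameter generation by passing an rng to the VarInfo constructor.
vi = VarInfo(Xoshiro(468), model)

# Compute the value and gradient once with ForwardDiff ...
ref = run_ad(model, AutoForwardDiff(); varinfo=vi)
expected = (ref.value_actual, ref.grad_actual)   # assumed ADResult field names

# ... then reuse them for another backend instead of recomputing the ground truth.
res = run_ad(model, AutoReverseDiff(); varinfo=vi, expected_value_and_grad=expected)

# Prepare the gradient at one set of parameters and evaluate it at another.
prep_params = [0.0, 0.0]
eval_params = [0.5, -1.0]
vi_prep = DynamicPPL.unflatten(vi, prep_params)
res2 = run_ad(model, AutoReverseDiff(); varinfo=vi_prep, params=eval_params)
```

Passing `expected_value_and_grad` is mainly worthwhile when several backends are checked against the same model and parameters. For a single backend, the default ForwardDiff reference is likely simpler.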