Skip to content

Commit 976885c

Browse files
committed
Fix docstring + use AbstractFloat
1 parent 9bf0593 commit 976885c

File tree

1 file changed

+13
-16
lines changed

1 file changed

+13
-16
lines changed

src/test_utils/ad.jl

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,28 +21,28 @@ it's the default AD backend used in Turing.jl.
2121
const REFERENCE_ADTYPE = AutoForwardDiff()
2222

2323
"""
24-
ADIncorrectException{T<:Real}
24+
ADIncorrectException{T<:AbstractFloat}
2525
2626
Exception thrown when an AD backend returns an incorrect value or gradient.
2727
2828
The type parameter `T` is the numeric type of the value and gradient.
2929
"""
30-
struct ADIncorrectException{T<:Real} <: Exception
30+
struct ADIncorrectException{T<:AbstractFloat} <: Exception
3131
value_expected::T
3232
value_actual::T
3333
grad_expected::Vector{T}
3434
grad_actual::Vector{T}
3535
end
3636

3737
"""
38-
ADResult{Tparams<:Real,Tresult<:Real}
38+
ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
3939
4040
Data structure to store the results of the AD correctness test.
4141
4242
The type parameter `Tparams` is the numeric type of the parameters passed in;
4343
`Tresult` is the type of the value and the gradient.
4444
"""
45-
struct ADResult{Tparams<:Real,Tresult<:Real}
45+
struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
4646
"The DynamicPPL model that was tested"
4747
model::Model
4848
"The VarInfo that was used"
@@ -76,9 +76,9 @@ end
7676
value_atol=1e-6,
7777
grad_atol=1e-6,
7878
varinfo::AbstractVarInfo=link(VarInfo(model), model),
79-
params::Union{Nothing,Vector{<:Real}}=nothing,
79+
params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
8080
reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
81-
expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
81+
expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
8282
verbose=true,
8383
)::ADResult
8484
@@ -108,12 +108,9 @@ Everything else is optional, and can be categorised into several groups:
108108
DynamicPPL contains several different types of VarInfo objects which change
109109
the way model evaluation occurs. If you want to use a specific type of
110110
VarInfo, pass it as the `varinfo` argument. Otherwise, it will default to
111-
using a `TypedVarInfo` generated from the model.
112-
113-
It will also perform _linking_, that is, the parameters in the VarInfo will
114-
be transformed to unconstrained Euclidean space if they aren't already in
115-
that space. Note that the act of linking may change the length of the
116-
parameters. To disable linking, set `linked=false`.
111+
using a linked `TypedVarInfo` generated from the model. Here, _linked_
112+
means that the parameters in the VarInfo have been transformed to
113+
unconstrained Euclidean space if they aren't already in that space.
117114
118115
2. _How to specify the parameters._
119116
@@ -172,12 +169,12 @@ function run_ad(
172169
adtype::AbstractADType;
173170
test::Bool=true,
174171
benchmark::Bool=false,
175-
value_atol::Real=1e-6,
176-
grad_atol::Real=1e-6,
172+
value_atol::AbstractFloat=1e-6,
173+
grad_atol::AbstractFloat=1e-6,
177174
varinfo::AbstractVarInfo=link(VarInfo(model), model),
178-
params::Union{Nothing,Vector{<:Real}}=nothing,
175+
params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
179176
reference_adtype::AbstractADType=REFERENCE_ADTYPE,
180-
expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
177+
expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
181178
verbose=true,
182179
)::ADResult
183180
if isnothing(params)

0 commit comments

Comments
 (0)