Skip to content

Commit 1544a40

Browse files
committed
Update ADTypeCheckContext
1 parent 6f11379 commit 1544a40

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

test/ad.jl

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,25 @@ Check that the element types in `vi` are compatible with the ADType of `context`
137137
Throw an `IncompatibleADTypeError` if an incompatible element type is encountered.
138138
"""
139139
function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo)
140+
# If we are using InitFromPrior or InitFromUniform to generate new values,
141+
# then the parameter type will be Any, so we should skip the check.
142+
lc = DynamicPPL.leafcontext(context)
143+
if lc isa DynamicPPL.InitContext{
144+
<:Any,<:Union{DynamicPPL.InitFromPrior,DynamicPPL.InitFromUniform}
145+
}
146+
return nothing
147+
end
148+
# Note that `get_param_eltype` will return `Any` with e.g. InitFromPrior or
149+
# InitFromUniform, so this will fail. But on the bright side, you would never _really_
150+
# use AD with those strategies, so that's fine. The cases where you do want to
151+
# use this are DefaultContext (i.e., old, slow, LogDensityFunction) and
152+
# InitFromParams{<:VectorWithRanges} (i.e., new, fast, LogDensityFunction), and
153+
# both of those give you sensible results for `get_param_eltype`.
154+
param_eltype = DynamicPPL.get_param_eltype(vi, context)
140155
valids = valid_eltypes(context)
141-
for val in vi[:]
142-
valtype = typeof(val)
143-
if !any(valtype .<: valids)
144-
throw(IncompatibleADTypeError(valtype, adtype(context)))
145-
end
156+
if !(any(param_eltype .<: valids))
157+
throw(IncompatibleADTypeError(param_eltype, adtype(context)))
146158
end
147-
return nothing
148159
end
149160

150161
# A bunch of tilde_assume/tilde_observe methods that just call the same method on the child
@@ -199,10 +210,10 @@ end
199210
@testset "Expected: $expected_adtype, Actual: $actual_adtype" begin
200211
if actual_adtype == expected_adtype
201212
# Check that this does not throw an error.
202-
sample(contextualised_tm, sampler, 2)
213+
sample(contextualised_tm, sampler, 2; check_model=false)
203214
else
204215
@test_throws AbstractWrongADBackendError sample(
205-
contextualised_tm, sampler, 2
216+
contextualised_tm, sampler, 2; check_model=false
206217
)
207218
end
208219
end

0 commit comments

Comments
 (0)