@@ -137,14 +137,25 @@ Check that the element types in `vi` are compatible with the ADType of `context`
137137Throw an `IncompatibleADTypeError` if an incompatible element type is encountered.
138138"""
139139function check_adtype (context:: ADTypeCheckContext , vi:: DynamicPPL.AbstractVarInfo )
140+ # If we are using InitFromPrior or InitFromUniform to generate new values,
141+ # then the parameter type will be Any, so we should skip the check.
142+ lc = DynamicPPL. leafcontext (context)
143+ if lc isa DynamicPPL. InitContext{
144+ <: Any ,<: Union{DynamicPPL.InitFromPrior,DynamicPPL.InitFromUniform}
145+ }
146+ return nothing
147+ end
148+ # Note that `get_param_eltype` will return `Any` with e.g. InitFromPrior or
149+ # InitFromUniform, so this will fail. But on the bright side, you would never _really_
150+ # use AD with those strategies, so that's fine. The cases where you do want to
151+ # use this are DefaultContext (i.e., old, slow, LogDensityFunction) and
152+ # InitFromParams{<:VectorWithRanges} (i.e., new, fast, LogDensityFunction), and
153+ # both of those give you sensible results for `get_param_eltype`.
154+ param_eltype = DynamicPPL. get_param_eltype (vi, context)
140155 valids = valid_eltypes (context)
141- for val in vi[:]
142- valtype = typeof (val)
143- if ! any (valtype .< : valids)
144- throw (IncompatibleADTypeError (valtype, adtype (context)))
145- end
156+ if ! (any (param_eltype .< : valids))
157+ throw (IncompatibleADTypeError (param_eltype, adtype (context)))
146158 end
147- return nothing
148159end
149160
150161# A bunch of tilde_assume/tilde_observe methods that just call the same method on the child
@@ -199,10 +210,10 @@ end
199210 @testset " Expected: $expected_adtype , Actual: $actual_adtype " begin
200211 if actual_adtype == expected_adtype
201212 # Check that this does not throw an error.
202- sample (contextualised_tm, sampler, 2 )
213+ sample (contextualised_tm, sampler, 2 ; check_model = false )
203214 else
204215 @test_throws AbstractWrongADBackendError sample (
205- contextualised_tm, sampler, 2
216+ contextualised_tm, sampler, 2 ; check_model = false
206217 )
207218 end
208219 end
0 commit comments