Skip to content

Commit 0cc3e45

Browse files
committed
Update ADTypeCheckContext
1 parent 6f11379 commit 0cc3e45

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

test/ad.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,17 @@ Check that the element types in `vi` are compatible with the ADType of `context`
137137
Throw an `IncompatibleADTypeError` if an incompatible element type is encountered.
138138
"""
139139
function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo)
140+
# Note that `get_param_eltype` will return `Any` with e.g. InitFromPrior or
141+
# InitFromUniform, so this will fail. But on the bright side, you would never _really_
142+
# use AD with those strategies, so that's fine. The cases where you do want to
143+
# use this are DefaultContext (i.e., old, slow, LogDensityFunction) and
144+
# InitFromParams{<:VectorWithRanges} (i.e., new, fast, LogDensityFunction), and
145+
# both of those give you sensible results for `get_param_eltype`.
146+
param_eltype = DynamicPPL.get_param_eltype(vi, context)
140147
valids = valid_eltypes(context)
141-
for val in vi[:]
142-
valtype = typeof(val)
143-
if !any(valtype .<: valids)
144-
throw(IncompatibleADTypeError(valtype, adtype(context)))
145-
end
148+
if !(param_eltype .<: valids)
149+
throw(IncompatibleADTypeError(valtype, adtype(context)))
146150
end
147-
return nothing
148151
end
149152

150153
# A bunch of tilde_assume/tilde_observe methods that just call the same method on the child

0 commit comments

Comments
 (0)