Skip to content

Commit ce9c0a0

Browse files
committed
Update ADTypeCheckContext
1 parent 6f11379 commit ce9c0a0

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

test/ad.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,17 @@ Check that the element types in `vi` are compatible with the ADType of `context`
137137
Throw an `IncompatibleADTypeError` if an incompatible element type is encountered.
138138
"""
139139
function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo)
140+
# Note that `get_param_eltype` will return `Any` with e.g. InitFromPrior or
141+
# InitFromUniform, so this will fail. But on the bright side, you would never _really_
142+
# use AD with those strategies, so that's fine. The cases where you do want to
143+
# use this are DefaultContext (i.e., old, slow, LogDensityFunction) and
144+
# InitFromParams{<:VectorWithRanges} (i.e., new, fast, LogDensityFunction), and
145+
# both of those give you sensible results for `get_param_eltype`.
146+
param_eltype = DynamicPPL.get_param_eltype(vi, context)
140147
valids = valid_eltypes(context)
141-
for val in vi[:]
142-
valtype = typeof(val)
143-
if !any(valtype .<: valids)
144-
throw(IncompatibleADTypeError(valtype, adtype(context)))
145-
end
148+
if !(param_eltype .<: valids)
149+
throw(IncompatibleADTypeError(valtype, adtype(context)))
146150
end
147-
return nothing
148151
end
149152

150153
# A bunch of tilde_assume/tilde_observe methods that just call the same method on the child
@@ -176,7 +179,11 @@ end
176179
"""
177180
All the ADTypes on which we want to run the tests.
178181
"""
179-
ADTYPES = [AutoForwardDiff(), AutoReverseDiff(; compile=false)]
182+
adtypes = (
183+
AutoForwardDiff(),
184+
AutoReverseDiff(),
185+
# Don't need to test Mooncake as it doesn't use tracer types
186+
)
180187
if INCLUDE_MOONCAKE
181188
push!(ADTYPES, AutoMooncake(; config=nothing))
182189
end

0 commit comments

Comments
 (0)