-
Notifications
You must be signed in to change notification settings - Fork 2
Closed
Description
This is why, for example, the AD test results at https://turinglang.org/ADTests/pr/ (built from #23) currently show many errors.
The failure has currently been minimised to the following example:
# Minimal reproducer: computes a forward-mode Enzyme gradient of a DynamicPPL
# model's log density with respect to its flattened parameters.
# Per the comment on the struct blocks below, the gradient call at the end of the
# module only fails to run cleanly while all three `*Struct` blocks are defined.
# NOTE(review): the exact error message is not reproduced in this file — confirm
# against the original report.
module MWE
using DynamicPPL: DynamicPPL, VarInfo, @model
import Enzyme: Enzyme, set_runtime_activity, Forward, Reverse, Const
using Distributions: InverseGamma, MvNormal, product_distribution
using LinearAlgebra: Diagonal
# Two-parameter model: `s` is a product of two independent InverseGamma(2, 3)
# distributions, and `x` (default `[1.5, 2.0]`) is drawn from a zero-mean MvNormal
# with diagonal covariance `Diagonal(s)`.
# Since `x` is a model argument, DynamicPPL presumably treats it as observed data
# (hence the "observe" in the name) — verify against DynamicPPL docs if relevant.
@model function demo_assume_multivariate_observe(x = [1.5, 2.0])
s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
x ~ MvNormal([0.0, 0.0], Diagonal(s))
end
# Removing any one of these three blocks makes the gradient call below run fine.
# Each block defines a struct with an untyped field `a` and overloads
# `getproperty` on the *type object itself* (`::Type{...}`, not on instances),
# forwarding to `getfield`. None of these types is used elsewhere in this module;
# they are kept only because their presence triggers the failure.
struct PStruct
a
end
Base.getproperty(::Type{PStruct}, s::Symbol) = getfield(PStruct, s)
struct QStruct
a
end
Base.getproperty(::Type{QStruct}, s::Symbol) = getfield(QStruct, s)
struct RStruct
a
end
Base.getproperty(::Type{RStruct}, s::Symbol) = getfield(RStruct, s)
# =========================================
# Instantiate the model with its default `x`, build a VarInfo for it, and take
# `varinfo[:]` as the point at which the gradient is evaluated (presumably the
# flattened parameter vector, since `f` below passes it to `unflatten`).
model = demo_assume_multivariate_observe()
varinfo = VarInfo(model)
params = varinfo[:]
# Log density of `model` at the parameter vector `x`:
#   1. `unflatten` builds a VarInfo like `varinfo` but holding the values in `x`;
#   2. `evaluate!!` runs the model under `DefaultContext()` and returns a 2-tuple,
#      whose second element is used as the updated VarInfo;
#   3. the accumulated log probability is read via `vi.logp[]`.
# NOTE(review): `logp` is an internal field dereferenced with `[]` (presumably a
# Ref); this access may need updating across DynamicPPL versions.
function f(
x::AbstractVector, model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo
)
varinfo_new = DynamicPPL.unflatten(varinfo, x)
_, vi = DynamicPPL.evaluate!!(model, varinfo_new, DynamicPPL.DefaultContext())
return vi.logp[]
end
# Forward-mode gradient with runtime activity enabled. `model` and `varinfo` are
# wrapped in `Const`, so only `params` is differentiated.
Enzyme.gradient(set_runtime_activity(Forward), f, params, Const(model), Const(varinfo))
end
Metadata
Metadata
Assignees
Labels
No labels