Skip to content

Clean up LogDensityFunctions interface code + setADtype #2473

@penelopeysm

Description

@penelopeysm

# Wrap `ℓ` for automatic differentiation. The AD backend is chosen by calling
# `getADType` on the context of `ℓ` (obtained via `DynamicPPL.getcontext`), and the
# work is then delegated to the two-argument `ADgradient(adtype, ℓ)` method.
# NOTE(review): neither `LogDensityProblemsAD.ADgradient` nor
# `DynamicPPL.LogDensityFunction` is owned by this package, so this method is type
# piracy; consider moving it to DynamicPPL.
function LogDensityProblemsAD.ADgradient(ℓ::DynamicPPL.LogDensityFunction)
    adtype = getADType(DynamicPPL.getcontext(ℓ))
    return LogDensityProblemsAD.ADgradient(adtype, ℓ)
end
# Log density of a `LogDensityFunction` whose context is a `DefaultContext`, evaluated
# at parameter values supplied as a `NamedTuple` rather than a flat vector.
# `DynamicPPL.unflatten` builds a varinfo holding `x`, and the log joint of
# `f.model` is computed against that varinfo.
# NOTE(review): the result depends on the `unflatten` method for the concrete
# varinfo type actually returning a varinfo that contains `x`.
function LogDensityProblems.logdensity(
    f::Turing.LogDensityFunction{<:AbstractVarInfo,<:Model,<:DynamicPPL.DefaultContext},
    x::NamedTuple,
)
    varinfo_at_x = DynamicPPL.unflatten(f.varinfo, x)
    return DynamicPPL.logjoint(f.model, varinfo_at_x)
end
# TODO: make a nicer `set_namedtuple!` and move these functions to DynamicPPL.
# Return a deep copy of `vi` after `set_namedtuple!` has written the values in `θ`
# into it. The caller's `vi` is left unmodified.
# Previously the updated copy was discarded and the original `vi` was returned, so
# `θ` never reached the caller.
# NOTE(review): this extends `DynamicPPL.unflatten` for a DynamicPPL type, which is
# type piracy; consider moving it to DynamicPPL.
function DynamicPPL.unflatten(vi::TypedVarInfo, θ::NamedTuple)
    # Copy first so the in-place `set_namedtuple!` does not mutate the caller's varinfo.
    new_vi = deepcopy(vi)
    set_namedtuple!(new_vi, θ)
    return new_vi
end
# Build a new `SimpleVarInfo` whose values are `θ`, reusing the `logp` and
# `transformation` fields of `vi`. The original `vi` is not modified.
# NOTE(review): this extends `DynamicPPL.unflatten` for a DynamicPPL type, which is
# type piracy; consider moving it to DynamicPPL.
function DynamicPPL.unflatten(vi::SimpleVarInfo, θ::NamedTuple)
    logp = vi.logp
    transformation = vi.transformation
    return SimpleVarInfo(θ, logp, transformation)
end

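This code closely duplicates existing code in DynamicPPL. In any case, it is type piracy: Turing adds methods to functions it does not own (for example `DynamicPPL.unflatten` and `LogDensityProblemsAD.ADgradient`) for argument types it also does not own. It should therefore be moved to DynamicPPL or deleted, as appropriate.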

It is unclear why Aqua.jl's type-piracy check does not flag these methods.

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions