Fast Log Density Function #1113
Closed
Changes from 5 commits
33 commits
7cddac7 Fast Log Density Function (penelopeysm)
5ed4295 Make it work with AD (penelopeysm)
e199520 Optimise performance for identity VarNames (penelopeysm)
4cefaca Mark `get_range_and_linked` as having zero derivative (penelopeysm)
6dfd106 Update comment (penelopeysm)
41ee7f3 make AD testing / benchmarking use FastLDF (penelopeysm)
22e32a6 Fix tests (penelopeysm)
79cc128 Optimise away `make_evaluate_args_and_kwargs` (penelopeysm)
f7c6a78 const func annotation (penelopeysm)
b1a7650 Disable benchmarks on non-typed-Metadata-VarInfo (penelopeysm)
e60873a Fix `_evaluate!!` correctly to handle submodels (penelopeysm)
fa0664e Actually fix submodel evaluate (penelopeysm)
09a1fbb Document thoroughly and organise code (penelopeysm)
7306ba4 Support more VarInfos, make it thread-safe (?) (penelopeysm)
53bccc1 fix bug in parsing ranges from metadata/VNV (penelopeysm)
30b9247 Fix get_param_eltype for TSVI (penelopeysm)
316937a Disable Enzyme benchmark (penelopeysm)
075cee8 Don't override _evaluate!!, that breaks ForwardDiff (sometimes) (penelopeysm)
5f5a92c Move FastLDF to experimental for now (penelopeysm)
0716de5 Fix imports, add tests, etc (penelopeysm)
cd2461e More test fixes (penelopeysm)
1b8b873 Fix imports / tests (penelopeysm)
ff5680d Remove AbstractFastEvalContext (penelopeysm)
500d5ac Changelog and patch bump (penelopeysm)
e560c30 Add correctness tests, fix imports (penelopeysm)
22e9dbe Merge branch 'main' into py/fastldf (penelopeysm)
c3bdcd0 Merge branch 'main' into py/fastldf (penelopeysm)
86d8a73 Concretise parameter vector in tests (penelopeysm)
4ec0c72 Merge branch 'main' into py/fastldf (penelopeysm)
4b324b0 Add zero-allocation tests (penelopeysm)
c55171b Add Chairmarks as test dep (penelopeysm)
77deae9 Disable allocations tests on multi-threaded (penelopeysm)
8715446 Fast InitContext (#1125) (penelopeysm)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,16 +1,15 @@ | ||
| module DynamicPPLEnzymeCoreExt | ||
|
|
||
| - if isdefined(Base, :get_extension) | ||
| - using DynamicPPL: DynamicPPL | ||
| - using EnzymeCore | ||
| - else | ||
| - using ..DynamicPPL: DynamicPPL | ||
| - using ..EnzymeCore | ||
| - end | ||
| + using DynamicPPL: DynamicPPL | ||
| + using EnzymeCore | ||
|
|
||
| # Mark is_transformed as having 0 derivative. The `nothing` return value is not significant, Enzyme | ||
| # only checks whether such a method exists, and never runs it. | ||
| @inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.is_transformed), args...) = | ||
| nothing | ||
| + # Likewise for get_range_and_linked. | ||
| + @inline EnzymeCore.EnzymeRules.inactive( | ||
| + ::typeof(DynamicPPL.get_range_and_linked), args... | ||
| + ) = nothing | ||
|
|
||
| end |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,10 @@ | ||
| module DynamicPPLMooncakeExt | ||
|
|
||
| - using DynamicPPL: DynamicPPL, is_transformed | ||
| + using DynamicPPL: DynamicPPL, is_transformed, get_range_and_linked | ||
| using Mooncake: Mooncake | ||
|
|
||
| # This is purely an optimisation. | ||
| Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg} | ||
| + Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(get_range_and_linked),Vararg} | ||
|
|
||
| end # module |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,176 @@ | ||
| struct OnlyAccsVarInfo{Accs<:AccumulatorTuple} <: AbstractVarInfo | ||
| accs::Accs | ||
| end | ||
| DynamicPPL.getaccs(vi::OnlyAccsVarInfo) = vi.accs | ||
| DynamicPPL.maybe_invlink_before_eval!!(vi::OnlyAccsVarInfo, ::Model) = vi | ||
| DynamicPPL.setaccs!!(::OnlyAccsVarInfo, accs::AccumulatorTuple) = OnlyAccsVarInfo(accs) | ||
|
|
||
| struct RangeAndLinked | ||
| # indices that the variable corresponds to in the vectorised parameter | ||
| range::UnitRange{Int} | ||
| # whether it's linked | ||
| is_linked::Bool | ||
| end | ||
|
|
||
| struct FastLDFContext{N<:NamedTuple,T<:AbstractVector{<:Real}} <: AbstractContext | ||
| # The ranges of identity VarNames are stored in a NamedTuple for improved performance | ||
| # (it's around 1.5x faster). | ||
| iden_varname_ranges::N | ||
| # This Dict stores the ranges for all other VarNames | ||
| varname_ranges::Dict{VarName,RangeAndLinked} | ||
| # The full parameter vector which we index into to get variable values | ||
| params::T | ||
| end | ||
| DynamicPPL.NodeTrait(::FastLDFContext) = IsLeaf() | ||
| function get_range_and_linked( | ||
| ctx::FastLDFContext, ::VarName{sym,typeof(identity)} | ||
| ) where {sym} | ||
| return ctx.iden_varname_ranges[sym] | ||
| end | ||
| function get_range_and_linked(ctx::FastLDFContext, vn::VarName) | ||
| return ctx.varname_ranges[vn] | ||
| end | ||
|
|
||
| function tilde_assume!!( | ||
| ctx::FastLDFContext, right::Distribution, vn::VarName, vi::OnlyAccsVarInfo | ||
| ) | ||
| # Don't need to read the data from the varinfo at all since it's | ||
| # all inside the context. | ||
| range_and_linked = get_range_and_linked(ctx, vn) | ||
| y = @view ctx.params[range_and_linked.range] | ||
| f = if range_and_linked.is_linked | ||
| from_linked_vec_transform(right) | ||
| else | ||
| from_vec_transform(right) | ||
| end | ||
| x, inv_logjac = with_logabsdet_jacobian(f, y) | ||
| vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right) | ||
| return x, vi | ||
| end | ||
|
|
||
| function tilde_observe!!( | ||
| ::FastLDFContext, | ||
| right::Distribution, | ||
| left, | ||
| vn::Union{VarName,Nothing}, | ||
| vi::OnlyAccsVarInfo, | ||
| ) | ||
| # This is the same as for DefaultContext | ||
| vi = accumulate_observe!!(vi, right, left, vn) | ||
| return left, vi | ||
| end | ||
|
|
||
| struct FastLDF{ | ||
| M<:Model, | ||
| F<:Function, | ||
| N<:NamedTuple, | ||
| AD<:Union{ADTypes.AbstractADType,Nothing}, | ||
| ADP<:Union{Nothing,DI.GradientPrep}, | ||
| } | ||
| _model::M | ||
| _getlogdensity::F | ||
| # See FastLDFContext for explanation of these two fields | ||
| _iden_varname_ranges::N | ||
| _varname_ranges::Dict{VarName,RangeAndLinked} | ||
| _adtype::AD | ||
| _adprep::ADP | ||
|
|
||
| function FastLDF( | ||
| model::Model, | ||
| getlogdensity::Function, | ||
| # This only works with typed Metadata-varinfo. | ||
| # Obviously, this can be generalised later. | ||
| varinfo::VarInfo{<:NamedTuple{syms}}; | ||
| adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, | ||
| ) where {syms} | ||
| # Figure out which variable corresponds to which index, and | ||
| # which variables are linked. | ||
| all_iden_ranges = NamedTuple() | ||
| all_ranges = Dict{VarName,RangeAndLinked}() | ||
| offset = 1 | ||
| for sym in syms | ||
| md = varinfo.metadata[sym] | ||
| for (vn, idx) in md.idcs | ||
| len = length(md.ranges[idx]) | ||
| is_linked = md.is_transformed[idx] | ||
| range = offset:(offset + len - 1) | ||
| if AbstractPPL.getoptic(vn) === identity | ||
| all_iden_ranges = merge( | ||
| all_iden_ranges, | ||
| NamedTuple(( | ||
| AbstractPPL.getsym(vn) => RangeAndLinked(range, is_linked), | ||
| )), | ||
| ) | ||
| else | ||
| all_ranges[vn] = RangeAndLinked(range, is_linked) | ||
| end | ||
| offset += len | ||
| end | ||
| end | ||
| # Do AD prep if needed | ||
| prep = if adtype === nothing | ||
| nothing | ||
| else | ||
| # Make backend-specific tweaks to the adtype | ||
| adtype = tweak_adtype(adtype, model, varinfo) | ||
| x = [val for val in varinfo[:]] | ||
| DI.prepare_gradient( | ||
| FastLogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), | ||
| adtype, | ||
| x, | ||
| ) | ||
| end | ||
|
|
||
| return new{ | ||
| typeof(model), | ||
| typeof(getlogdensity), | ||
| typeof(all_iden_ranges), | ||
| typeof(adtype), | ||
| typeof(prep), | ||
| }( | ||
| model, getlogdensity, all_iden_ranges, all_ranges, adtype, prep | ||
| ) | ||
| end | ||
| end | ||
|
|
||
| struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple} | ||
| _model::M | ||
| _getlogdensity::F | ||
| _iden_varname_ranges::N | ||
| _varname_ranges::Dict{VarName,RangeAndLinked} | ||
| end | ||
| function (f::FastLogDensityAt)(params::AbstractVector{<:Real}) | ||
| ctx = FastLDFContext(f._iden_varname_ranges, f._varname_ranges, params) | ||
| model = DynamicPPL.setleafcontext(f._model, ctx) | ||
| # This can obviously also be optimised for the case where not | ||
| # all accumulators are needed. | ||
| accs = AccumulatorTuple(( | ||
| LogPriorAccumulator(), LogLikelihoodAccumulator(), LogJacobianAccumulator() | ||
| )) | ||
| _, vi = DynamicPPL._evaluate!!(model, OnlyAccsVarInfo(accs)) | ||
| return f._getlogdensity(vi) | ||
| end | ||
|
|
||
| function LogDensityProblems.logdensity(fldf::FastLDF, params::AbstractVector{<:Real}) | ||
| return FastLogDensityAt( | ||
| fldf._model, fldf._getlogdensity, fldf._iden_varname_ranges, fldf._varname_ranges | ||
| )( | ||
| params | ||
| ) | ||
| end | ||
|
|
||
| function LogDensityProblems.logdensity_and_gradient( | ||
| fldf::FastLDF, params::AbstractVector{<:Real} | ||
| ) | ||
| return DI.value_and_gradient( | ||
| FastLogDensityAt( | ||
| fldf._model, | ||
| fldf._getlogdensity, | ||
| fldf._iden_varname_ranges, | ||
| fldf._varname_ranges, | ||
| ), | ||
| fldf._adprep, | ||
| fldf._adtype, | ||
| params, | ||
| ) | ||
| end | ||
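
The range bookkeeping in the `FastLDF` constructor and the lookup in `tilde_assume!!` can be illustrated without DynamicPPL. What follows is a minimal sketch in plain Julia, not the PR's implementation: `RangeAndLinkedSketch`, `build_ranges` and `read_variable` are hypothetical names, the `(symbol, length, is_linked)` triples stand in for VarInfo metadata, and the only "linked" transform modelled is the log transform of a positive scalar (inverse `exp`, whose log-abs-det Jacobian at `y` is `y`).

```julia
# Hypothetical stand-in for the PR's RangeAndLinked; plain Julia, no DynamicPPL.
struct RangeAndLinkedSketch
    range::UnitRange{Int}   # indices into the flat parameter vector
    is_linked::Bool         # whether the stored values are in linked space
end

# Walk the variables in order and hand each one a contiguous range, in the
# same way the constructor in the diff advances `offset`.
function build_ranges(vars)
    ranges = NamedTuple()
    offset = 1
    for (sym, len, is_linked) in vars
        r = offset:(offset + len - 1)
        entry = NamedTuple{(sym,)}((RangeAndLinkedSketch(r, is_linked),))
        ranges = merge(ranges, entry)
        offset += len
    end
    return ranges
end

# Read one variable from the flat vector without copying it. The second
# return value is the log-abs-det Jacobian of the inverse transform.
# Assumption: linked variables here are log-transformed positive values.
function read_variable(ranges, params, sym)
    ral = ranges[sym]
    y = @view params[ral.range]
    if ral.is_linked
        return exp.(y), sum(y)
    else
        return collect(y), 0.0
    end
end

ranges = build_ranges([(:mu, 1, false), (:sigma, 1, true), (:beta, 3, false)])
params = [0.5, log(2.0), 1.0, 2.0, 3.0]
sigma, inv_logjac = read_variable(ranges, params, :sigma)
beta, _ = read_variable(ranges, params, :beta)
```

In the diff, the negated value `-inv_logjac` is what gets passed to `accumulate_assume!!`. The sketch stores every range in a `NamedTuple` keyed by symbol; the PR does this only for identity VarNames and falls back to a `Dict` for the rest.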
It turns out that a substantial share of the performance gains came from this innocuous line 😄 Just optimising the implementation of this function for normal VarInfo, plus sticking some `@views` in appropriate places, gets the old LogDensityFunction to within ~50% of FastLDF's performance. See #1115.
In some ways the simplification here is a result of the current approach: because the linked status is no longer stored in a VarInfo object but rather in a separate struct, there is no longer any need to care about linking / invlinking a VarInfo.
But overall I think this means that the performance gains from the current approach are much more comprehensible.
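
As a rough illustration of why the `@views` mentioned above matter, the sketch below compares slicing with a copy against slicing with a view, using only Base Julia. It is not taken from the PR or from #1115; `sum_copy`, `sum_view` and `measure` are hypothetical helpers, and the exact byte counts depend on the Julia version.

```julia
params = collect(1.0:10.0)

copied = params[3:5]        # indexing with a range allocates a new Vector
viewed = @view params[3:5]  # a SubArray that aliases the parent's memory

# Mutating the parent is visible through the view but not through the copy.
params[3] = 100.0

sum_copy(p) = sum(p[3:5])        # allocates a temporary slice on each call
sum_view(p) = sum(@view p[3:5])  # reads the parent's memory directly

# Measure inside a function so that global-scope overhead is not counted.
function measure(p)
    sum_copy(p); sum_view(p)     # warm up so compilation is excluded
    bytes_copy = @allocated sum_copy(p)
    bytes_view = @allocated sum_view(p)
    return bytes_copy, bytes_view
end

bytes_copy, bytes_view = measure(params)
```

On a typical setup the view version is expected to report fewer allocated bytes than the copy version (often zero), which is the kind of saving that adds up when a log density is evaluated many times.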