Transforms could be optimised #1142

@penelopeysm

Looking at #1141: given that the trivial model

@model trivial() = x ~ Normal()

contains only a distribution whose link function is the identity, why does it still produce one allocation when running with linked parameters? The answer can be found by looking at

julia> DynamicPPL.from_linked_vec_transform(Normal())
DynamicPPL.UnwrapSingletonTransform{Tuple{}}(()) ∘ (identity ∘ DynamicPPL.ReshapeTransform{Tuple{Int64}, Tuple{}}((1,), ()))

which in turn suggests a potential optimisation. Here's a sneak peek of the optimisation:

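using Bijectors, Distributions, DynamicPPL

# Extracts the single element of a length-1 vector; the log-abs-det Jacobian is zero.
struct Only end
# Inverse of Only: wraps a scalar back into a length-1 vector.
struct NotOnly end
DynamicPPL.from_linked_vec_transform(::Normal) = Only()
DynamicPPL.from_vec_transform(::Normal) = Only()
Bijectors.with_logabsdet_jacobian(::Only, x::AbstractVector) = (x[], 0.0)
Bijectors.isinvertible(::Only) = true
Bijectors.inverse(::Only) = NotOnly()
Bijectors.with_logabsdet_jacobian(::NotOnly, y) = ([y], 0.0)
# Like Only, but also applies exp to map the unconstrained value to a positive one.
struct ExpOnly end
# Inverse of ExpOnly: takes the log and wraps the result in a length-1 vector.
struct LogVec end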
# The next line is only valid when the distribution is truncated to (0, Inf), i.e. the
# lower bound is zero and there is no upper bound. We could make this more correct by
# introducing a 'truncated at zero' wrapper for distributions.
DynamicPPL.from_linked_vec_transform(::Truncated{<:Cauchy}) = ExpOnly()
DynamicPPL.from_vec_transform(::Truncated{<:Cauchy}) = Only()
function Bijectors.with_logabsdet_jacobian(::ExpOnly, y::AbstractVector)
    yi = y[]
    x = exp(yi)
    return (x, yi)
end
Bijectors.isinvertible(::ExpOnly) = true
Bijectors.inverse(::ExpOnly) = LogVec()
function Bijectors.with_logabsdet_jacobian(::LogVec, x)
    y = log(x)
    return ([y], -y)
end
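
For reference, the log-Jacobian terms returned by the transforms above can be checked by hand. This is only a sketch of the scalar calculus, with $y$ denoting the unconstrained (linked) value and $x$ the value in the distribution's support, matching the variable names in the code:

```latex
\begin{aligned}
\texttt{Only}:    \quad & x = y,      & \log\left|\frac{dx}{dy}\right| &= \log 1 = 0, \\
\texttt{ExpOnly}: \quad & x = e^{y},  & \log\left|\frac{dx}{dy}\right| &= \log e^{y} = y, \\
\texttt{LogVec}:  \quad & y = \log x, & \log\left|\frac{dy}{dx}\right| &= \log\frac{1}{x} = -\log x = -y.
\end{aligned}
```

The `ExpOnly` and `LogVec` terms sum to zero, as expected for a transform and its inverse.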

(In passing: making from_linked_vec_transform and from_vec_transform return the same thing also helps type stability in the most general case, where the link status is not known ahead of time.)
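
To illustrate the type-stability remark, here is a minimal, package-free sketch. The names `from_vec`, `from_linked_vec`, `pick`, and the `Val` tags are hypothetical stand-ins for illustration only; they are not DynamicPPL's API. The point is just that when the link status is a runtime `Bool`, the inferred return type is concrete only if both branches return the same type.

```julia
# Hypothetical stand-ins for the transform types in the patch above.
struct Only end
struct ExpOnly end

# Hypothetical lookup functions, dispatching on a tag instead of a real distribution.
from_vec(::Val{:normal}) = Only()
from_linked_vec(::Val{:normal}) = Only()        # same type in both cases
from_vec(::Val{:halfcauchy}) = Only()
from_linked_vec(::Val{:halfcauchy}) = ExpOnly() # differs from the unlinked case

# The link status is only known at runtime here.
pick(d, linked::Bool) = linked ? from_linked_vec(d) : from_vec(d)

# Inferred return types for a runtime `linked` flag.
normal_ret = only(Base.return_types(pick, (Val{:normal}, Bool)))
cauchy_ret = only(Base.return_types(pick, (Val{:halfcauchy}, Bool)))

println(normal_ret)              # a single concrete type
println(isconcretetype(cauchy_ret))  # false: the result is a Union of two types
```

In this sketch the `:normal` case infers to a concrete type, whereas the `:halfcauchy` case infers to a `Union`, which downstream code then has to branch on.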

And the results:

# trivial unlinked         after #1141                          after #1141 + with the above patch
eval      ----             10.649 ns                            5.582 ns
grad (FD) ----             39.345 ns (3 allocs: 96 bytes)       50.362 ns (3 allocs: 96 bytes)
grad (RD) ----             2.629 μs (44 allocs: 1.500 KiB)      2.995 μs (44 allocs: 1.500 KiB)
grad (MC) ----             269.660 ns (4 allocs: 192 bytes)     228.161 ns (4 allocs: 192 bytes)
grad (EN) ----             100.539 ns (2 allocs: 64 bytes)      69.643 ns (2 allocs: 64 bytes)

# eight-schools unlinked   after #1141                          after #1141 + with the above patch
eval      ----             170.667 ns (4 allocs: 256 bytes)     152.043 ns (4 allocs: 256 bytes)
grad (FD) ----             775.211 ns (11 allocs: 2.594 KiB)    742.513 ns (11 allocs: 2.594 KiB)
grad (RD) ----             35.083 μs (555 allocs: 20.297 KiB)   40.667 μs (555 allocs: 20.297 KiB)
grad (MC) ----             1.264 μs (12 allocs: 784 bytes)      1.193 μs (12 allocs: 784 bytes)
grad (EN) ----             630.319 ns (13 allocs: 832 bytes)    584.157 ns (13 allocs: 832 bytes)

# trivial linked           after #1141                          after #1141 + with the above patch
eval      ----             14.664 ns (1 allocs: 32 bytes)       5.586 ns
grad (FD) ----             43.595 ns (4 allocs: 144 bytes)      50.418 ns (3 allocs: 96 bytes)
grad (RD) ----             2.718 μs (52 allocs: 1.781 KiB)      3.044 μs (44 allocs: 1.500 KiB)
grad (MC) ----             319.596 ns (6 allocs: 256 bytes)     297.446 ns (4 allocs: 192 bytes)
grad (EN) ----             172.721 ns (6 allocs: 208 bytes)     69.511 ns (2 allocs: 64 bytes)

# eight-schools linked     after #1141                          after #1141 + with the above patch
eval      ----             241.115 ns (7 allocs: 352 bytes)     157.659 ns (4 allocs: 256 bytes)
grad (FD) ----             886.719 ns (13 allocs: 2.812 KiB)    773.487 ns (11 allocs: 2.594 KiB)
grad (RD) ----             38.250 μs (593 allocs: 21.641 KiB)   41.708 μs (567 allocs: 20.688 KiB)
grad (MC) ----             1.511 μs (18 allocs: 976 bytes)      1.183 μs (12 allocs: 784 bytes)
grad (EN) ----             998.600 ns (33 allocs: 1.469 KiB)    602.898 ns (13 allocs: 832 bytes)

I think this is essentially the final missing 'feature' of SimpleVarInfo that hasn't yet been ported over to FastLDF. Here, though, we don't need to make it FastLDF-exclusive (unlike the SVI feature, which is SVI-exclusive): better transforms make everything faster, not just FastLDF.
