Looking at #1141: given that the trivial model

```julia
@model trivial() = x ~ Normal()
```

only has a distribution for which the link function is the identity, why does it still produce one allocation when run with linked parameters? The answer can be found by looking at:
```julia
julia> DynamicPPL.from_linked_vec_transform(Normal())
DynamicPPL.UnwrapSingletonTransform{Tuple{}}(()) ∘ (identity ∘ DynamicPPL.ReshapeTransform{Tuple{Int64}, Tuple{}}((1,), ()))
```

which in turn suggests a potential optimisation. Here's a sneak peek of the optimisation:
```julia
using DynamicPPL, Distributions, Bijectors

# Singleton transform types: they have no fields, so constructing them does not allocate.
struct Only end     # length-1 vector -> scalar (no change of variables)
struct NotOnly end  # scalar -> length-1 vector (inverse of `Only`)

# For `Normal` the link is the identity, so linked and unlinked transforms coincide.
DynamicPPL.from_linked_vec_transform(::Normal) = Only()
DynamicPPL.from_vec_transform(::Normal) = Only()

# Unwrapping is not a change of variables, so log|det J| = 0.
Bijectors.with_logabsdet_jacobian(::Only, x::AbstractVector) = (x[], 0.0)
Bijectors.isinvertible(::Only) = true
Bijectors.inverse(::Only) = NotOnly()
Bijectors.with_logabsdet_jacobian(::NotOnly, y) = ([y], 0.0)

struct ExpOnly end  # linked length-1 vector y -> x = exp(y[]) > 0
struct LogVec end   # x > 0 -> [log(x)] (inverse of `ExpOnly`)

# The next line is only valid when the lower bound of the truncated distribution is zero
# (and there is no upper bound). We could make this more correct by introducing a
# 'truncated at zero' wrapper for distributions.
DynamicPPL.from_linked_vec_transform(::Truncated{<:Cauchy}) = ExpOnly()
DynamicPPL.from_vec_transform(::Truncated{<:Cauchy}) = Only()

function Bijectors.with_logabsdet_jacobian(::ExpOnly, y::AbstractVector)
    yi = y[]
    x = exp(yi)
    # d/dy exp(y) = exp(y), so log|dx/dy| = y.
    return (x, yi)
end
Bijectors.isinvertible(::ExpOnly) = true
Bijectors.inverse(::ExpOnly) = LogVec()

function Bijectors.with_logabsdet_jacobian(::LogVec, x)
    y = log(x)
    # d/dx log(x) = 1/x, so log|dy/dx| = -log(x) = -y.
    return ([y], -y)
end
```

In passing, note also that making `from_linked_vec_transform` and `from_vec_transform` return the same thing will also help with type stability in the most general case, where the link status is not known ahead of time.
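A minimal, self-contained illustration of the type-stability point, using hypothetical stand-in types `Unwrap` and `ExpUnwrap` rather than DynamicPPL's real transform types. When the link status is only a runtime `Bool`, a function that chooses between two transforms gets a concrete inferred return type only if both branches return the same type; otherwise inference produces a `Union`.

```julia
# Hypothetical stand-ins for two transform types (not DynamicPPL's real types).
struct Unwrap end     # e.g. "unwrap a length-1 vector"
struct ExpUnwrap end  # e.g. "unwrap, then apply exp"

# The link status is only known at runtime, as in the most general case.
# Here both branches return the same type...
choose_same(islinked::Bool) = islinked ? Unwrap() : Unwrap()
# ...and here the branches return different types.
choose_diff(islinked::Bool) = islinked ? ExpUnwrap() : Unwrap()

# Ask the compiler which return type it infers for a Bool argument.
same_rt = only(Base.return_types(choose_same, (Bool,)))
diff_rt = only(Base.return_types(choose_diff, (Bool,)))

println("choose_same infers ", same_rt, " (concrete: ", isconcretetype(same_rt), ")")
println("choose_diff infers ", diff_rt, " (concrete: ", isconcretetype(diff_rt), ")")
```

With the `Normal` methods above, both functions return `Only()`, which corresponds to the `choose_same` case.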
And the results (FD = ForwardDiff, RD = ReverseDiff, MC = Mooncake, EN = Enzyme):
| trivial, unlinked | after #1141 | after #1141 + the above patch |
|---|---|---|
| eval | 10.649 ns | 5.582 ns |
| grad (FD) | 39.345 ns (3 allocs: 96 bytes) | 50.362 ns (3 allocs: 96 bytes) |
| grad (RD) | 2.629 μs (44 allocs: 1.500 KiB) | 2.995 μs (44 allocs: 1.500 KiB) |
| grad (MC) | 269.660 ns (4 allocs: 192 bytes) | 228.161 ns (4 allocs: 192 bytes) |
| grad (EN) | 100.539 ns (2 allocs: 64 bytes) | 69.643 ns (2 allocs: 64 bytes) |

| eight-schools, unlinked | after #1141 | after #1141 + the above patch |
|---|---|---|
| eval | 170.667 ns (4 allocs: 256 bytes) | 152.043 ns (4 allocs: 256 bytes) |
| grad (FD) | 775.211 ns (11 allocs: 2.594 KiB) | 742.513 ns (11 allocs: 2.594 KiB) |
| grad (RD) | 35.083 μs (555 allocs: 20.297 KiB) | 40.667 μs (555 allocs: 20.297 KiB) |
| grad (MC) | 1.264 μs (12 allocs: 784 bytes) | 1.193 μs (12 allocs: 784 bytes) |
| grad (EN) | 630.319 ns (13 allocs: 832 bytes) | 584.157 ns (13 allocs: 832 bytes) |

| trivial, linked | after #1141 | after #1141 + the above patch |
|---|---|---|
| eval | 14.664 ns (1 allocs: 32 bytes) | 5.586 ns |
| grad (FD) | 43.595 ns (4 allocs: 144 bytes) | 50.418 ns (3 allocs: 96 bytes) |
| grad (RD) | 2.718 μs (52 allocs: 1.781 KiB) | 3.044 μs (44 allocs: 1.500 KiB) |
| grad (MC) | 319.596 ns (6 allocs: 256 bytes) | 297.446 ns (4 allocs: 192 bytes) |
| grad (EN) | 172.721 ns (6 allocs: 208 bytes) | 69.511 ns (2 allocs: 64 bytes) |

| eight-schools, linked | after #1141 | after #1141 + the above patch |
|---|---|---|
| eval | 241.115 ns (7 allocs: 352 bytes) | 157.659 ns (4 allocs: 256 bytes) |
| grad (FD) | 886.719 ns (13 allocs: 2.812 KiB) | 773.487 ns (11 allocs: 2.594 KiB) |
| grad (RD) | 38.250 μs (593 allocs: 21.641 KiB) | 41.708 μs (567 allocs: 20.688 KiB) |
| grad (MC) | 1.511 μs (18 allocs: 976 bytes) | 1.183 μs (12 allocs: 784 bytes) |
| grad (EN) | 998.600 ns (33 allocs: 1.469 KiB) | 602.898 ns (13 allocs: 832 bytes) |
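To make the linked-parameter results easier to compare, the snippet below computes before/after ratios from the timings reported above (μs converted to ns). It is plain arithmetic on the posted numbers, not a new benchmark; `linked_results` and `speedup` are hypothetical names used only for this sketch.

```julia
# Timings copied from the linked-parameter tables above, in ns.
# Each entry: (label, after #1141, after #1141 + the above patch)
linked_results = [
    ("trivial eval",            14.664,  5.586),
    ("trivial grad (FD)",       43.595,  50.418),
    ("trivial grad (RD)",       2718.0,  3044.0),
    ("trivial grad (MC)",       319.596, 297.446),
    ("trivial grad (EN)",       172.721, 69.511),
    ("eight-schools eval",      241.115, 157.659),
    ("eight-schools grad (FD)", 886.719, 773.487),
    ("eight-schools grad (RD)", 38250.0, 41708.0),
    ("eight-schools grad (MC)", 1511.0,  1183.0),
    ("eight-schools grad (EN)", 998.600, 602.898),
]

# A ratio above 1 means the patched version is faster; below 1 means it is slower.
speedup(before, after) = before / after

for (label, before, after) in linked_results
    println(rpad(label, 26), round(speedup(before, after); digits = 2), "x")
end
```

Read this way, eval, MC and EN improve in every linked case, whereas RD is slower on both models and FD is slower on the trivial model.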
I think this is essentially the last missing 'feature' of SimpleVarInfo that hasn't yet been ported over to FastLDF. But here we don't need to make it a FastLDF-exclusive feature (unlike the corresponding SVI feature, which is SVI-exclusive): better transforms make everything faster, not just FastLDF.