Improve performance of logpdf for DiagNormal and IsoNormal #1991
base: master
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1991      +/-   ##
==========================================
+ Coverage   86.28%   86.30%   +0.02%
==========================================
  Files         146      146
  Lines        8787     8801      +14
==========================================
+ Hits         7582     7596      +14
  Misses       1205     1205
I wonder if we could just use
sqmahal(d::MvNormal, x::AbstractVector) = invquad(d.Σ, LazyArrays.@~(x .- d.μ))
as LazyArrays seems to optimize FillArrays broadcasts as well, so we might be able to keep the optimizations for Zeros means.
@@ -261,6 +261,25 @@ logdetcov(d::MvNormal) = logdet(d.Σ)

sqmahal(d::MvNormal, x::AbstractVector) = invquad(d.Σ, x .- d.μ)

function sqmahal(d::DiagNormal, x::AbstractVector)
    # Faster than above as this avoids calculating (x .- d.µ)
    T = promote_type(partype(d), eltype(x))
This is not correct, e.g. if the parameters and x are arrays of integers. One could either compute the first element outside of the loop or possibly use a functional approach (e.g. with mapreduce and Base.Broadcast.broadcasted(...)).
The functional approach probably also works better with GPU arrays.
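For illustration, the functional approach suggested above could look roughly like the following. This is only a sketch under stated assumptions: sqmahal_diag is a hypothetical standalone helper that takes the mean and the vector of variances directly (it is not the code from this PR or from Distributions.jl), and it assumes non-empty inputs.

```julia
# Hypothetical helper, NOT the Distributions.jl implementation:
# squared Mahalanobis distance for a diagonal covariance whose
# diagonal (the variances) is passed as the vector σ2.
function sqmahal_diag(x::AbstractVector, μ::AbstractVector, σ2::AbstractVector)
    # Lazy broadcast object: the differences x .- μ are never materialized,
    # so no temporary array is allocated.
    bc = Base.Broadcast.broadcasted(x, μ, σ2) do xi, μi, vi
        abs2(xi - μi) / vi
    end
    # instantiate checks that the axes of all arguments are compatible and
    # throws a DimensionMismatch otherwise. sum derives the result type from
    # the summed values, so no manual promote_type is needed.
    # Assumption: the inputs are non-empty.
    return sum(Base.Broadcast.instantiate(bc))
end

# Integer inputs still give a floating-point result.
sqmahal_diag([1, 2, 3], [0, 0, 0], [1, 2, 3])
```

Because the element type is taken from the computed values rather than from promote_type, integer and Bool inputs are handled without special cases.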
end

function sqmahal(d::IsoNormal, x::AbstractVector)
    T = promote_type(partype(d), eltype(x))
Very artificial, but this is also not guaranteed to be correct, e.g. for Boolean values.
    for i in eachindex(x)
        @inbounds sum += abs2(x[i] - d.μ[i]) / d.Σ[i, i]
Indices might not be compatible.
    T = promote_type(partype(d), eltype(x))
    sum = zero(T)
    for i in eachindex(x)
        @inbounds sum += abs2(x[i] - d.μ[i])
Indices might not be compatible.
    for i in eachindex(x)
        @inbounds sum += abs2(x[i] - d.μ[i])
    end
    return sum / d.Σ[1, 1]
Might not use one-based indexing.
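If the explicit loop is kept, the index remarks above could be addressed along these lines. Again a sketch only: sqdist_iso is a hypothetical standalone helper, the scalar variance is passed in directly (which sidesteps the Σ[1, 1] lookup), and the element types of x and μ are assumed to be concrete numeric types.

```julia
# Hypothetical helper, NOT the code from this PR:
# squared distance scaled by a single variance σ2 (isotropic case).
function sqdist_iso(x::AbstractVector, μ::AbstractVector, σ2::Real)
    # Start from a zero of the type the summands actually have, instead of
    # promote_type (for Bool inputs the difference is already an Int).
    # Assumption: eltype(x) and eltype(μ) are concrete numeric types.
    s = abs2(zero(eltype(x)) - zero(eltype(μ)))
    # eachindex with two arrays throws a DimensionMismatch if their indices
    # differ, so @inbounds is safe here, also for offset arrays.
    for i in eachindex(x, μ)
        @inbounds s += abs2(x[i] - μ[i])
    end
    return s / σ2
end

sqdist_iso([1.0, 2.0, 3.0], [0.0, 0.0, 0.0], 2.0)
```

Iterating over eachindex(x, μ) rather than eachindex(x) is what makes the bounds-check elision valid for both arrays.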
This PR (sort of) closes #1989. I didn't attempt to make rand() faster because the performance drop there was smaller, but the changes here improve the performance of logpdf on DiagNormal and IsoNormal by about 2–3x.

The offending code that made the previous logpdf slower was the calculation of x .- d.µ:

Distributions.jl/src/multivariate/mvnormal.jl
Line 262 in abb151c

The new methods avoid this calculation and are therefore both faster and allocation-free.