Skip to content

Commit 652bb38

Browse files
authored
Reduce allocations in all + any and remove isfinite type piracy (#388)
* Reduce allocations in `all` + `any` and remove `isfinite` type piracy * Fix `isfinite` check
1 parent 3394707 commit 652bb38

File tree

7 files changed

+14
-16
lines changed

7 files changed

+14
-16
lines changed

research/src/riemannian_hmc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ function ∂H∂r(
356356
r::AbstractVecOrMat,
357357
)
358358
H = h.metric.G(θ)
359-
# if any(.!(isfinite.(H)))
359+
# if !all(isfinite, H)
360360
# println("θ: ", θ)
361361
# println("H: ", H)
362362
# end

research/src/riemannian_hmc_utility.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ function prepare_sample_target(hps, θ₀, ℓπ)
4343
fstabilize = H -> H + hps.λ * I
4444
Gfunc = x -> begin
4545
H = fstabilize(Hfunc(x)[3])
46-
any(.!(isfinite.(H))) ? diagm(ones(length(x))) : H
46+
all(isfinite, H) ? H : diagm(ones(length(x)))
4747
end
4848
_∂G∂θfunc = gen_∂G∂θ_fwd(Vfunc, θ₀; f = fstabilize) # size==(4, 2)
4949
∂G∂θfunc = x -> reshape_∂G∂θ(_∂G∂θfunc(x)) # size==(2, 2, 2)

research/tests/riemannian_hmc.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ using AdvancedHMC: neg_energy, energy
4343
hamiltonian = Hamiltonian(metric, kinetic, ℓπ, ∂ℓπ∂θ)
4444

4545
if hessmap isa SoftAbsMap || # only test kinetic energy for SoftAbsMap as that of IdentityMap can be non-PD
46-
all(iszero.(x)) # or for x==0 that I know it's PD
46+
all(iszero, x) # or for x==0 that I know it's PD
4747
@testset "Kinetic energy" begin
4848
Σ = hamiltonian.metric.map(hamiltonian.metric.G(x))
4949
@test neg_energy(hamiltonian, r, x) logpdf(MvNormal(zeros(D), Σ), r)

src/adaptation/stepsize.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ function adapt_stepsize!(
127127
@debug "Adapting step size..." new_ϵ = ϵ old_ϵ = da.state.ϵ
128128

129129
# TODO: we might want to remove this when all other numerical issues are correctly handelled
130-
if any(isnan.(ϵ)) || any(isinf.(ϵ))
130+
if !all(isfinite, ϵ)
131131
@warn "Incorrect ϵ = ; ϵ_previous = $(da.state.ϵ) is used instead."
132132
# FIXME: this revert is buggy for batch mode
133133
@unpack m, ϵ, x_bar, H_bar = state

src/hamiltonian.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,15 @@ struct PhasePoint{T<:AbstractVecOrMat{<:AbstractFloat},V<:DualValue}
5858
ℓκ::V # Cached neg kinect energy for the current r.
5959
function PhasePoint::T, r::T, ℓπ::V, ℓκ::V) where {T,V}
6060
@argcheck length(θ) == length(r) == length(ℓπ.gradient) == length(ℓπ.gradient)
61-
if any(isfinite.((θ, r, ℓπ, ℓκ)) .== false)
62-
# @warn "The current proposal will be rejected due to numerical error(s)." isfinite.((θ, r, ℓπ, ℓκ))
63-
# NOTE eltype has to be inlined to avoid type stability issue; see #267
61+
if !isfinite(ℓπ)
6462
ℓπ = DualValue(
65-
map(v -> isfinite(v) ? v : -eltype(T)(Inf), ℓπ.value),
63+
map(v -> isfinite(v) ? v : oftype(v, -Inf), ℓπ.value),
6664
ℓπ.gradient,
6765
)
66+
end
67+
if !isfinite(ℓκ)
6868
ℓκ = DualValue(
69-
map(v -> isfinite(v) ? v : -eltype(T)(Inf), ℓκ.value),
69+
map(v -> isfinite(v) ? v : oftype(v, -Inf), ℓκ.value),
7070
ℓκ.gradient,
7171
)
7272
end
@@ -105,7 +105,6 @@ end
105105

106106

107107
Base.isfinite(v::DualValue) = all(isfinite, v.value) && all(isfinite, v.gradient)
108-
Base.isfinite(v::AbstractVecOrMat) = all(isfinite, v)
109108
Base.isfinite(z::PhasePoint) = isfinite(z.ℓπ) && isfinite(z.ℓκ)
110109

111110
###

src/trajectory.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ function transition(
275275
hamiltonian_energy = H,
276276
hamiltonian_energy_error = H - H0,
277277
# check numerical error in proposed phase point.
278-
numerical_error = isfinite(H′),
278+
numerical_error = !all(isfinite, H′),
279279
),
280280
stat.integrator),
281281
)
@@ -296,13 +296,12 @@ function accept_phasepoint!(
296296
end
297297
function accept_phasepoint!(z::T, z′::T, is_accept) where {T<:PhasePoint{<:AbstractMatrix}}
298298
# Revert unaccepted proposals in `z′`
299-
is_reject = (!).(is_accept)
300-
if any(is_reject)
299+
if !all(is_accept)
301300
# Convert logical indexing to number indexing to support CUDA.jl
302301
# NOTE: for x::CuArray, x[:,Vector{Bool}] is NOT supported
303302
# x[:,CuVector{Int}] is NOT supported
304303
# x[:,Vector{Int}] is supported
305-
is_reject = findall(is_reject) |> Array
304+
is_reject = Vector(findall(!, is_accept))
306305
z′.θ[:, is_reject] = z.θ[:, is_reject]
307306
z′.r[:, is_reject] = z.r[:, is_reject]
308307
z′.ℓπ.value[is_reject] = z.ℓπ.value[is_reject]

test/sampler.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ function test_stats(
2323
:hamiltonian_energy_error,
2424
:is_adapt,
2525
)
26-
@test all(map(s -> in(name, propertynames(s)), stats))
26+
@test all(s -> in(name, propertynames(s)), stats)
2727
end
2828
is_adapts = getproperty.(stats, :is_adapt)
2929
@test is_adapts[1:n_adapts] == ones(Bool, n_adapts)
@@ -49,7 +49,7 @@ function test_stats(
4949
:tree_depth,
5050
:numerical_error,
5151
)
52-
@test all(map(s -> in(name, propertynames(s)), stats))
52+
@test all(s -> in(name, propertynames(s)), stats)
5353
end
5454
is_adapts = getproperty.(stats, :is_adapt)
5555
@test is_adapts[1:n_adapts] == ones(Bool, n_adapts)

0 commit comments

Comments
 (0)