Skip to content

Commit 8777eb7

Browse files
authored
Excise remaining Zygote legacy code (#132)
* Remove legacy collect statement * Remove Zygote-related remark * Remove and Zygote-related remark * Remove redundant comments * Remove outdated comment * Display progress through pseudo point tests * Excise zygote_friendly_map as it is redundant * Remove zygote-friendly map include from runtests * Bump patch * Update readme timings discussion
1 parent 9ddac9c commit 8777eb7

File tree

13 files changed

+22
-83
lines changed

13 files changed

+22
-83
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TemporalGPs"
22
uuid = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
33
authors = ["Will Tebbutt and contributors"]
4-
version = "0.7.0"
4+
version = "0.7.1"
55

66
[deps]
77
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,14 @@ This tells TemporalGPs that you want all parameters of `f` and anything derived
7676

7777

7878

79-
# Benchmarking Results
79+
# Benchmarking Results (Old)
8080

8181
![](/examples/benchmarks.png)
8282

8383
"naive" timings are with the usual [AbstractGPs.jl](https://https://github.com/JuliaGaussianProcesses/AbstractGPs.jl/) inference routines, and is the default implementation for GPs. "lgssm" timings are conducted using `to_sde` with no additional arguments. "static-lgssm" uses the `SArrayStorage(Float64)` option discussed above.
8484

85-
Gradient computations use Mooncake. Custom adjoints have been implemented to achieve this level of performance.
85+
Gradient computations were performed using [Zygote.jl](https://github.com/FluxML/Zygote.jl/), and required many custom adjoints.
86+
You should see similar results to this using [Mooncake.jl](https://github.com/compintell/Mooncake.jl) or [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl).
8687

8788

8889
# Relevant literature

src/TemporalGPs.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ module TemporalGPs
3535
# Various bits-and-bobs. Often commiting some type piracy.
3636
include(joinpath("util", "linear_algebra.jl"))
3737
include(joinpath("util", "scan.jl"))
38-
include(joinpath("util", "zygote_friendly_map.jl"))
3938

4039
include(joinpath("util", "gaussian.jl"))
4140
include(joinpath("util", "mul.jl"))

src/gp/lti_sde.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ end
118118
function lgssm_components(
119119
m::AbstractGPs.MeanFunction, k::Kernel, t::AbstractVector, storage_type::StorageType
120120
)
121-
m = collect(mean_vector(m, t)) # `collect` is needed as there are still issues with Zygote and FillArrays.
121+
m = mean_vector(m, t)
122122
As, as, Qs, (Hs, hs), x0 = lgssm_components(k, t, storage_type)
123123
hs = add_proj_mean(hs, m)
124124
return As, as, Qs, (Hs, hs), x0

src/models/lgssm.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,9 @@ end
193193
function posterior(prior::LGSSM, y::AbstractVector)
194194
_check_inputs(prior, y)
195195
new_trans, xf = _a_bit_of_posterior(prior, y)
196-
A = zygote_friendly_map(x -> x.A, new_trans)
197-
a = zygote_friendly_map(x -> x.a, new_trans)
198-
Q = zygote_friendly_map(x -> x.Q, new_trans)
196+
A = map(x -> x.A, new_trans)
197+
a = map(x -> x.a, new_trans)
198+
Q = map(x -> x.Q, new_trans)
199199
return LGSSM(GaussMarkovModel(reverse(ordering(prior)), A, a, Q, xf), prior.emissions)
200200
end
201201

src/models/missings.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
# Several strategies for missing data handling were attempted.
22
# 1. Use `missing`s as expected. This turned out to be problematic for type-stability.
3-
# 2. Sentinel values (NaNs). Also problematic for type-stability because Zygote.
3+
# 2. Sentinel values (NaNs).
44
# 3. (The adopted strategy) - replace missings with arbitrary observations and _large_
55
# observation noises. While not optimal, type-stability is preserved inside the
66
# performance-sensitive code.
7-
#
8-
# In an ideal world, strategy 1 would work. Unfortunately Zygote isn't up to it yet.
97

108
function AbstractGPs.logpdf(
119
model::LGSSM, y::AbstractVector{Union{Missing, T}},
@@ -28,7 +26,7 @@ function transform_model_and_obs(
2826
model::LGSSM, y::AbstractVector{<:Union{Missing, T}},
2927
) where {T<:Union{<:AbstractVector, <:Real}}
3028
Σs_filled_in, y_filled_in = fill_in_missings(
31-
zygote_friendly_map(noise_cov, emissions(model)), y,
29+
map(noise_cov, emissions(model)), y,
3230
)
3331
model_with_missings = replace_observation_noise_cov(model, Σs_filled_in)
3432
return model_with_missings, y_filled_in
@@ -54,11 +52,11 @@ function _logpdf_volume_compensation(y::AbstractVector{<:Union{Missing, <:Real}}
5452
return count(ismissing, y) * log(2π * _large_var_const()) / 2
5553
end
5654

57-
function fill_in_missings(Σs::Vector, y::AbstractVector{Union{Missing, T}}) where {T}
55+
function fill_in_missings(Σs::AbstractVector, y::AbstractVector{Union{Missing, T}}) where {T}
5856
return _fill_in_missings(Σs, y)
5957
end
6058

61-
function _fill_in_missings(Σs::Vector, y::AbstractVector{Union{Missing, T}}) where {T}
59+
function _fill_in_missings(Σs::AbstractVector, y::AbstractVector{Union{Missing, T}}) where {T}
6260

6361
# Fill in observation covariance matrices with very large values.
6462
Σs_filled_in = map(eachindex(y)) do n

src/space_time/pseudo_point.jl

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -70,21 +70,14 @@ function AbstractGPs.elbo(fx::FiniteLTISDE, y::AbstractVector, z_r::AbstractVect
7070

7171
k = fx_dtc.f.f.kernel
7272
Cf_diags = kernel_diagonals(k, fx_dtc.x)
73-
# return Cf_diags
7473

7574
# Transform a vector into a vector-of-vectors.
7675
y_vecs = restructure(y, lgssm.emissions)
77-
78-
tmp = zygote_friendly_map(
79-
((Σ, Cf_diag, marg_diag, yn), ) -> begin
80-
Σ_, _ = fill_in_missings(Σ, yn)
81-
return sum(diag(Σ_ \ (Cf_diag - marg_diag.P))) -
82-
count(ismissing, yn) + size(Σ_, 1)
83-
end,
84-
zip(Σs, Cf_diags, marg_diags, y_vecs),
85-
)
86-
# return -sum(tmp) / 2
87-
76+
tmp = map(Σs, Cf_diags, marg_diags, y_vecs) do Σ, Cf_diag, marg_diag, yn
77+
Σ_, _ = fill_in_missings(Σ, yn)
78+
return sum(diag(Σ_ \ (Cf_diag - marg_diag.P))) -
79+
count(ismissing, yn) + size(Σ_, 1)
80+
end
8881
return logpdf(lgssm, y_vecs) - sum(tmp) / 2
8982
end
9083

src/space_time/rectilinear_grid.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ end
9191
# See docstring elsewhere for context.
9292
function noise_var_to_time_form(x::RectilinearGrid, S::Diagonal{<:Real})
9393
vs = restructure(diag(S), Fill(length(get_space(x)), length(get_times(x))))
94-
return zygote_friendly_map(v -> Diagonal(collect(v)), vs)
94+
return map(v -> Diagonal(collect(v)), vs)
9595
end
9696

9797
destructure(::RectilinearGrid, y::AbstractVector) = reduce(vcat, y)

src/util/zygote_friendly_map.jl

Lines changed: 0 additions & 39 deletions
This file was deleted.

test/gp/lti_sde.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ using KernelFunctions: kappa
33
using TemporalGPs: build_lgssm, StorageType, is_of_storage_type, lgssm_components
44
using Test
55

6-
# Everything is tested once the LGSSM is constructed, so it is sufficient just to ensure
7-
# that Zygote can handle construction.
6+
# Everything is tested once the LGSSM is constructed, so the logpdf bit of this test
7+
# function is probably redundant. It is good to do a little bit of integration testing
8+
# though.
89
function _logpdf_tester(f_naive::GP, y, storage::StorageType, σ², t::AbstractVector)
910
f = to_sde(f_naive, storage)
1011
return logpdf(f(t, σ²...), y)

0 commit comments

Comments
 (0)