Skip to content

Commit 1c4292c

Browse files
authored
Merge pull request #24 from probsys/20250226-fsaad-fixes
20250226 fsaad fixes
2 parents d8bfd03 + d130150 commit 1c4292c

File tree

5 files changed

+68
-33
lines changed

5 files changed

+68
-33
lines changed

docs/src/tutorials/overview.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -890,7 +890,7 @@
890890
"source": [
891891
"It is also possible to directly access the underlying predictive distribution of new data at arbitrary time series values by using [`AutoGP.predict_mvn`](@ref), which returns an instance of [`Distributions.MixtureModel`](https://juliastats.org/Distributions.jl/stable/mixture/#Distributions.MixtureModel). The [`Distributions.MvNormal`](https://juliastats.org/Distributions.jl/stable/multivariate/#Distributions.MvNormal) object corresponding to each of the 7 particles in the mixture can be extracted using [`Distributions.components`](https://juliastats.org/Distributions.jl/stable/mixture/#Distributions.components-Tuple{AbstractMixtureModel}) and the weights extracted using [`Distributions.probs`](https://juliastats.org/Distributions.jl/stable/mixture/#Distributions.probs-Tuple{AbstractMixtureModel}).\n",
892892
"\n",
893-
"Each `MvNormal` in the mixture has 18 dimensions corresponding to the lenght of `df_test.ds`."
893+
"Each `MvNormal` in the mixture has 18 dimensions corresponding to the length of `df_test.ds`."
894894
]
895895
},
896896
{

docs/src/tutorials/overview.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,7 @@ show(logps)
541541

542542
It is also possible to directly access the underlying predictive distribution of new data at arbitrary time series values by using [`AutoGP.predict_mvn`](@ref), which returns an instance of [`Distributions.MixtureModel`](https://juliastats.org/Distributions.jl/stable/mixture/#Distributions.MixtureModel). The [`Distributions.MvNormal`](https://juliastats.org/Distributions.jl/stable/multivariate/#Distributions.MvNormal) object corresponding to each of the 7 particles in the mixture can be extracted using [`Distributions.components`](https://juliastats.org/Distributions.jl/stable/mixture/#Distributions.components-Tuple{AbstractMixtureModel}) and the weights extracted using [`Distributions.probs`](https://juliastats.org/Distributions.jl/stable/mixture/#Distributions.probs-Tuple{AbstractMixtureModel}).
543543

544-
Each `MvNormal` in the mixture has 18 dimensions corresponding to the lenght of `df_test.ds`.
544+
Each `MvNormal` in the mixture has 18 dimensions corresponding to the length of `df_test.ds`.
545545

546546

547547
```julia

src/Callbacks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,11 @@ function make_smc_callback(fn::Function, model::AutoGP.GPModel; kwargs...)
115115
ds_permuted = model.ds[permutation]
116116
y_permuted = model.y[permutation]
117117

118-
# Remaining data.
118+
# Observed data.
119119
ds_obs = ds_permuted[1:step]
120120
y_obs = y_permuted[1:step]
121121

122-
# Future data.
122+
# Remaining data.
123123
ds_next = ds_permuted[step+1:end]
124124
y_next = y_permuted[step+1:end]
125125

src/api.jl

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
# limitations under the License.
1414

1515
import DataFrames
16-
import Dates
1716
import Distributions
1817
import Random
1918

2019
import Gen
2120

21+
using Dates
2222
using Match
2323

2424
using Distributions: MixtureModel
@@ -35,17 +35,18 @@ function seed!(seed)
3535
end
3636

3737
"""
38-
IndexType = Union{Vector{<:Real}, Vector{<:Dates.TimeType}}
38+
IndexType = Union{Vector{<:Real}, Vector{<:Date}, Vector{<:DateTime}}
3939
4040
Permitted Julia types for Gaussian process time points.
41-
`Real` numbers are ingested directly, treated as time points..
42-
Instances of `Dates.TimeType` are converted to numeric time points by using
41+
`Real` numbers are ingested directly, treated as time points.
42+
Instances of the `Dates` types are converted to numeric time points by using
4343
[`Dates.datetime2unix`](https://docs.julialang.org/en/v1/stdlib/Dates/#Dates.datetime2unix).
4444
"""
45-
const IndexType = Union{Vector{<:Real}, Vector{<:Dates.TimeType}}
45+
const IndexType = Union{Vector{<:Real}, Vector{<:Date}, Vector{<:DateTime}}
4646

47-
to_numeric(t::Vector{<:Dates.AbstractTime}) = @. Dates.datetime2unix(Dates.DateTime(t))
48-
to_numeric(t::Vector{<:Real}) = t
47+
to_numeric(t::DateTime) = datetime2unix(t)
48+
to_numeric(t::Date) = to_numeric(DateTime(t))
49+
to_numeric(t::Real) = t
4950

5051
"""
5152
struct GPModel
@@ -92,10 +93,10 @@ function GPModel(
9293
n_particles::Integer=Threads.nthreads(),
9394
config::GP.GPConfig=GP.GPConfig())
9495
# Save the transformations.
95-
ds_transform = Transforms.LinearTransform(to_numeric(ds), 0, 1)
96+
ds_transform = Transforms.LinearTransform(to_numeric.(ds), 0, 1)
9697
y_transform = Transforms.LinearTransform(Vector{Float64}(y), 1)
9798
# Transform the data.
98-
ds_numeric = Transforms.apply(ds_transform, to_numeric(ds))
99+
ds_numeric = Transforms.apply(ds_transform, to_numeric.(ds))
99100
y_numeric = Transforms.apply(y_transform, y)
100101
# Initialize the particle filter.
101102
observations = Gen.choicemap((:xs, y_numeric))
@@ -105,7 +106,7 @@ function GPModel(
105106
pf_state = Gen.initialize_particle_filter(
106107
Model.model, (ds_numeric, config), observations, n_particles)
107108
# Return the state.
108-
return GPModel(pf_state, config, ds, y, ds_transform, y_transform)
109+
return GPModel(pf_state, config, collect(ds), collect(y), ds_transform, y_transform)
109110
end
110111

111112
"""
@@ -170,8 +171,8 @@ the observed data. Inference is performed using sequential Monte Carlo.
170171
# Arguments
171172
- `model::GPModel`: Instance of the `GPModel` to use.
172173
- `schedule::Vector{<:Integer}`: Schedule for incorporating data for SMC, refer to [`Schedule`](@ref).
173-
- `n_mcmc::Int`: Number of involutive MCMC rejuvenation steps.
174-
- `n_hmc::Int`: Number of HMC steps per accepted involutive MCMC step.
174+
- `n_mcmc::Union{Integer,Vector{<:Integer}}`: Number of involutive MCMC rejuvenation steps. If vector, must have same length as `schedule`.
175+
- `n_hmc::Union{Integer,Vector{<:Integer}}`: Number of HMC steps per accepted involutive MCMC step. If vector, must have same length as `schedule`.
175176
- `biased::Bool`: Whether to bias the proposal to produce "short" structures.
176177
- `shuffle::Bool=true`: Whether to shuffle indexes `ds` or incorporate data in the given order.
177178
- `adaptive_resampling::Bool=true`: If `true` resamples based on ESS threshold, else at each step.
@@ -190,8 +191,8 @@ the observed data. Inference is performed using sequential Monte Carlo.
190191
function fit_smc!(
191192
model::GPModel;
192193
schedule::Vector{<:Integer},
193-
n_mcmc::Int,
194-
n_hmc::Int,
194+
n_mcmc::Union{Integer,Vector{<:Integer}},
195+
n_hmc::Union{Integer,Vector{<:Integer}},
195196
biased::Bool=false,
196197
shuffle::Bool=true,
197198
adaptive_resampling::Bool=true,
@@ -205,7 +206,7 @@ function fit_smc!(
205206
end
206207
# Obtain observed data.
207208
n = length(model.ds)
208-
ds_numeric = Transforms.apply(model.ds_transform, to_numeric(model.ds))
209+
ds_numeric = Transforms.apply(model.ds_transform, to_numeric.(model.ds))
209210
y_numeric = Transforms.apply(model.y_transform, model.y)
210211
permutation = shuffle ? Random.randperm(n) : collect(1:n)
211212
# Run SMC.
@@ -313,7 +314,7 @@ function fit_greedy!(
313314
!model.config.changepoints || error("AutoGP.fit_greedy! does not support changepoint operators.")
314315
(1 <= max_depth <= (model.config.max_depth == -1 ? Inf : model.config.max_depth)) || error("AutoGP.fit_greedy! requires positive and finite max_depth.")
315316
# Prepare observations.
316-
ds_numeric = Transforms.apply(model.ds_transform, to_numeric(model.ds))
317+
ds_numeric = Transforms.apply(model.ds_transform, to_numeric.(model.ds))
317318
y_numeric = Transforms.apply(model.y_transform, model.y)
318319
# Helper function for creating intermediate models for callback.
319320
make_greedy_submodel = (trace) -> begin
@@ -409,7 +410,30 @@ function add_data!(model::GPModel, ds::IndexType, y::Vector{<:Real})
409410
append!(model.ds, ds)
410411
append!(model.y, y)
411412
# Convert to numeric.
412-
ds_numeric = Transforms.apply(model.ds_transform, to_numeric(model.ds))
413+
ds_numeric = Transforms.apply(model.ds_transform, to_numeric.(model.ds))
414+
y_numeric = Transforms.apply(model.y_transform, model.y)
415+
# Prepare observations.
416+
observations = Gen.choicemap((:xs, y_numeric))
417+
!isnothing(model.config.noise) && (observations[:noise] = trace[:noise])
418+
# Run SMC step.
419+
Inference.smc_step!(model.pf_state, (ds_numeric, model.config), observations)
420+
end
421+
422+
"""
423+
remove_data!(model::GPModel, ds::IndexType, y::Vector{<:Real})
424+
Remove existing observations `ds` from `model`.
425+
"""
426+
function remove_data!(model::GPModel, ds::IndexType)
427+
# Find the data point.
428+
indexes = findall(x->x in ds, model.ds)
429+
if length(indexes) == 0
430+
error("No such time points $(ds).")
431+
end
432+
# Append the data.
433+
deleteat!(model.ds, indexes)
434+
deleteat!(model.y, indexes)
435+
# Convert to numeric.
436+
ds_numeric = Transforms.apply(model.ds_transform, to_numeric.(model.ds))
413437
y_numeric = Transforms.apply(model.y_transform, model.y)
414438
# Prepare observations.
415439
observations = Gen.choicemap((:xs, y_numeric))
@@ -452,7 +476,7 @@ function predict_mvn(
452476
if !(eltype(ds) <: eltype(model.ds))
453477
error("Invalid time $(ds), expected $(eltype(model.ds))")
454478
end
455-
ds_numeric = Transforms.apply(model.ds_transform, to_numeric(ds))
479+
ds_numeric = Transforms.apply(model.ds_transform, to_numeric.(ds))
456480
n_particles = num_particles(model)
457481
weights = particle_weights(model)
458482
distributions = Vector{MvNormal}(undef, n_particles)
@@ -584,11 +608,11 @@ function predict(
584608
if !(eltype(ds) <: eltype(model.ds))
585609
error("Invalid time $(ds), expected $(eltype(model.ds))")
586610
end
587-
ds_numeric = Transforms.apply(model.ds_transform, to_numeric(ds))
611+
ds_numeric = Transforms.apply(model.ds_transform, to_numeric.(ds))
588612
weights = particle_weights(model)
589613
n_particles = num_particles(model)
590614
frames = Vector(undef, n_particles)
591-
for i=1:n_particles
615+
Threads.@threads for i=1:n_particles
592616
y_mean, y_bounds = Inference.predict(
593617
model.pf_state.traces[i], ds_numeric;
594618
quantiles=quantiles, noise_pred=noise_pred)

src/inference_smc_anneal_data.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,9 @@ function run_smc_anneal_data(
145145
xs::Vector{Float64};
146146
config::GPConfig=GPConfig(),
147147
biased::Bool=false,
148-
n_particles::Int=4,
149-
n_mcmc::Int=10,
150-
n_hmc::Int=10,
148+
n_particles::Integer=4,
149+
n_mcmc::Union{Integer,Vector{<:Integer}}=10,
150+
n_hmc::Union{Integer,Vector{<:Integer}}=10,
151151
hmc_config=Dict(),
152152
permutation::Vector{<:Integer}=collect(1:length(ts)),
153153
schedule::Vector{<:Integer}=range(1:length(ts)),
@@ -170,6 +170,12 @@ function run_smc_anneal_data(
170170
@assert schedule[end] == length(ts)
171171
@assert all((schedule[2:end] .- schedule[1:end-1]) .> 0)
172172

173+
# Obtain n_mcmc and n_hmc.
174+
isa(n_mcmc, Integer) && begin n_mcmc = repeat([n_mcmc], length(schedule)) end
175+
isa(n_hmc, Integer) && begin n_hmc = repeat([n_hmc], length(schedule)) end
176+
@assert length(n_mcmc) == length(schedule)
177+
@assert length(n_hmc) == length(schedule)
178+
173179
# Initialize SMC particles from prior.
174180
@timeit elapsed begin
175181
observations = Gen.choicemap()
@@ -197,7 +203,7 @@ function run_smc_anneal_data(
197203
verbose=verbose)
198204

199205
# Run inference.
200-
for step in schedule
206+
for (i, step) in enumerate(schedule)
201207
verbose && println("Running SMC round $(step)/$(schedule[end])")
202208

203209
@timeit elapsed begin
@@ -231,13 +237,18 @@ function run_smc_anneal_data(
231237
rejuvenated = false
232238
if !adaptive_rejuvenation || resampled
233239
rejuvenated = true
234-
Threads.@threads for i=1:n_particles
235-
local trace = state.traces[i]
240+
Threads.@threads for p=1:n_particles
241+
local trace = state.traces[p]
236242
trace, = rejuvenate_particle_structure(
237-
trace, n_mcmc, n_hmc, biased;
238-
hmc_config=hmc_config, verbose=verbose, check=check,
243+
trace,
244+
n_mcmc[i],
245+
n_hmc[i],
246+
biased;
247+
hmc_config=hmc_config,
248+
verbose=verbose,
249+
check=check,
239250
observations=observations)
240-
state.traces[i] = trace
251+
state.traces[p] = trace
241252
end
242253
end
243254

0 commit comments

Comments
 (0)