Skip to content

Commit 4147467

Browse files
committed
style: restructure packages
1 parent 414c944 commit 4147467

22 files changed

+59
-40
lines changed

examples/zarr_async_trace.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ use nuts_rs::{
2424
use nuts_storable::{HasDims, Value};
2525
use rand::Rng;
2626
use thiserror::Error;
27-
use zarrs::filesystem::FilesystemStore;
2827
use zarrs_object_store::AsyncObjectStore;
2928

3029
/// A multivariate normal distribution model

src/adapt_strategy.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,18 @@ use nuts_storable::{HasDims, Storable};
55
use rand::Rng;
66
use serde::Serialize;
77

8+
use super::mass_matrix::MassMatrixAdaptStrategy;
9+
use super::stepsize::AcceptanceRateCollector;
10+
use super::stepsize::{StepSizeSettings, Strategy as StepSizeStrategy};
811
use crate::{
912
NutsError,
1013
chain::AdaptStrategy,
1114
euclidean_hamiltonian::EuclideanHamiltonian,
1215
hamiltonian::{DivergenceInfo, Hamiltonian, Point},
13-
mass_matrix_adapt::MassMatrixAdaptStrategy,
1416
math_base::Math,
1517
nuts::{Collector, NutsOptions},
1618
sampler_stats::{SamplerStats, StatsDims},
1719
state::State,
18-
stepsize_adapt::{StepSizeSettings, Strategy as StepSizeStrategy},
19-
stepsize_dual_avg::AcceptanceRateCollector,
2020
};
2121

2222
pub struct GlobalStrategy<M: Math, A: MassMatrixAdaptStrategy<M>> {
@@ -457,7 +457,7 @@ mod test {
457457

458458
#[test]
459459
fn instanciate_adaptive_sampler() {
460-
use crate::mass_matrix_adapt::Strategy;
460+
use crate::mass_matrix::Strategy;
461461

462462
let ndim = 10;
463463
let func = NormalLogp::new(ndim, 3.);

src/lib.rs

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -99,30 +99,20 @@
9999
mod adapt_strategy;
100100
mod chain;
101101
mod cpu_math;
102-
mod csv_storage;
103102
mod euclidean_hamiltonian;
104103
mod hamiltonian;
105-
mod hashmap_storage;
106-
mod low_rank_mass_matrix;
107104
mod mass_matrix;
108-
mod mass_matrix_adapt;
109105
mod math;
110106
mod math_base;
111107
mod model;
112-
#[cfg(feature = "ndarray")]
113-
mod ndarray_storage;
114108
mod nuts;
115109
mod sampler;
116110
mod sampler_stats;
117111
mod state;
118-
mod stepsize_adam;
119-
mod stepsize_adapt;
120-
mod stepsize_dual_avg;
112+
mod stepsize;
121113
mod storage;
122114
mod transform_adapt_strategy;
123115
mod transformed_hamiltonian;
124-
#[cfg(feature = "zarr")]
125-
mod zarr_storage;
126116

127117
pub use nuts_derive::Storable;
128118
pub use nuts_storable::{HasDims, ItemType, Storable, Value};
@@ -141,16 +131,15 @@ pub use sampler::{
141131
};
142132
pub use sampler_stats::SamplerStats;
143133

144-
pub use low_rank_mass_matrix::LowRankSettings;
145-
pub use mass_matrix_adapt::DiagAdaptExpSettings;
146-
pub use stepsize_adam::AdamOptions;
147-
pub use stepsize_adapt::{StepSizeAdaptMethod, StepSizeAdaptOptions, StepSizeSettings};
134+
pub use mass_matrix::DiagAdaptExpSettings;
135+
pub use mass_matrix::LowRankSettings;
136+
pub use stepsize::{AdamOptions, StepSizeAdaptMethod, StepSizeAdaptOptions, StepSizeSettings};
148137
pub use transform_adapt_strategy::TransformedSettings;
149138

150139
#[cfg(feature = "zarr")]
151-
pub use zarr_storage::{ZarrAsyncConfig, ZarrAsyncTraceStorage, ZarrConfig, ZarrTraceStorage};
140+
pub use storage::{ZarrAsyncConfig, ZarrAsyncTraceStorage, ZarrConfig, ZarrTraceStorage};
152141

153-
pub use csv_storage::{CsvConfig, CsvTraceStorage};
154-
pub use hashmap_storage::{HashMapConfig, HashMapValue};
142+
pub use storage::{CsvConfig, CsvTraceStorage};
143+
pub use storage::{HashMapConfig, HashMapValue};
155144
#[cfg(feature = "ndarray")]
156-
pub use ndarray_storage::{NdarrayConfig, NdarrayTrace, NdarrayValue};
145+
pub use storage::{NdarrayConfig, NdarrayTrace, NdarrayValue};

src/mass_matrix_adapt.rs renamed to src/mass_matrix/adapt.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ use nuts_derive::Storable;
44
use rand::Rng;
55
use serde::Serialize;
66

7+
use super::mass_matrix::{DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance};
78
use crate::{
89
Math, NutsError,
910
euclidean_hamiltonian::EuclideanPoint,
1011
hamiltonian::Point,
11-
mass_matrix::{DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance},
1212
nuts::{Collector, NutsOptions},
1313
sampler_stats::SamplerStats,
1414
};

src/low_rank_mass_matrix.rs renamed to src/mass_matrix/low_rank.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,10 @@ use itertools::Itertools;
55
use nuts_derive::Storable;
66
use serde::Serialize;
77

8+
use super::adapt::MassMatrixAdaptStrategy;
9+
use super::mass_matrix::{DrawGradCollector, MassMatrix};
810
use crate::{
9-
Math, NutsError,
10-
euclidean_hamiltonian::EuclideanPoint,
11-
hamiltonian::Point,
12-
mass_matrix::{DrawGradCollector, MassMatrix},
13-
mass_matrix_adapt::MassMatrixAdaptStrategy,
11+
Math, NutsError, euclidean_hamiltonian::EuclideanPoint, hamiltonian::Point,
1412
sampler_stats::SamplerStats,
1513
};
1614

@@ -507,9 +505,7 @@ mod test {
507505
use rand::{Rng, SeedableRng, rngs::SmallRng};
508506
use rand_distr::StandardNormal;
509507

510-
use crate::low_rank_mass_matrix::mat_all_finite;
511-
512-
use super::{estimate_mass_matrix, spd_mean};
508+
use super::{estimate_mass_matrix, mat_all_finite, spd_mean};
513509

514510
#[test]
515511
fn test_spd_mean() {
File renamed without changes.

src/mass_matrix/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
mod adapt;
2+
mod low_rank;
3+
mod mass_matrix;
4+
5+
pub use adapt::DiagAdaptExpSettings;
6+
pub(crate) use adapt::MassMatrixAdaptStrategy;
7+
pub(crate) use adapt::Strategy;
8+
pub use low_rank::LowRankSettings;
9+
pub(crate) use low_rank::{LowRankMassMatrix, LowRankMassMatrixStrategy};
10+
pub(crate) use mass_matrix::{DiagMassMatrix, MassMatrix};

src/sampler.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ use crate::{
2424
adapt_strategy::{EuclideanAdaptOptions, GlobalStrategy, GlobalStrategyStatsOptions},
2525
chain::{AdaptStrategy, Chain, NutsChain, StatOptions},
2626
euclidean_hamiltonian::EuclideanHamiltonian,
27-
low_rank_mass_matrix::{LowRankMassMatrix, LowRankMassMatrixStrategy, LowRankSettings},
2827
mass_matrix::DiagMassMatrix,
29-
mass_matrix_adapt::Strategy as DiagMassMatrixStrategy,
28+
mass_matrix::Strategy as DiagMassMatrixStrategy,
29+
mass_matrix::{LowRankMassMatrix, LowRankMassMatrixStrategy, LowRankSettings},
3030
math_base::Math,
3131
model::Model,
3232
nuts::NutsOptions,
File renamed without changes.

src/stepsize_adapt.rs renamed to src/stepsize/adapt.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ use rand::Rng;
44
use rand_distr::Uniform;
55
use serde::Serialize;
66

7+
use super::adam::{Adam, AdamOptions};
8+
use super::dual_avg::{AcceptanceRateCollector, DualAverage, DualAverageOptions};
79
use crate::{
810
Math, NutsError,
911
hamiltonian::{Direction, Hamiltonian, LeapfrogResult, Point},
1012
nuts::{Collector, NutsOptions},
1113
sampler_stats::SamplerStats,
12-
stepsize_adam::{Adam, AdamOptions},
13-
stepsize_dual_avg::{AcceptanceRateCollector, DualAverage, DualAverageOptions},
1414
};
1515
use std::fmt::Debug;
1616

0 commit comments

Comments
 (0)