Skip to content

Commit 62af0c4

Browse files
committed
style: restructure packages
1 parent 0bdfef3 commit 62af0c4

22 files changed

+59
-40
lines changed

examples/zarr_async_trace.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ use nuts_rs::{
2424
use nuts_storable::{HasDims, Value};
2525
use rand::Rng;
2626
use thiserror::Error;
27-
use zarrs::filesystem::FilesystemStore;
2827
use zarrs_object_store::AsyncObjectStore;
2928

3029
/// A multivariate normal distribution model

src/adapt_strategy.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,18 @@ use nuts_storable::{HasDims, Storable};
55
use rand::Rng;
66
use serde::Serialize;
77

8+
use super::mass_matrix::MassMatrixAdaptStrategy;
9+
use super::stepsize::AcceptanceRateCollector;
10+
use super::stepsize::{StepSizeSettings, Strategy as StepSizeStrategy};
811
use crate::{
912
NutsError,
1013
chain::AdaptStrategy,
1114
euclidean_hamiltonian::EuclideanHamiltonian,
1215
hamiltonian::{DivergenceInfo, Hamiltonian, Point},
13-
mass_matrix_adapt::MassMatrixAdaptStrategy,
1416
math_base::Math,
1517
nuts::{Collector, NutsOptions},
1618
sampler_stats::{SamplerStats, StatsDims},
1719
state::State,
18-
stepsize_adapt::{StepSizeSettings, Strategy as StepSizeStrategy},
19-
stepsize_dual_avg::AcceptanceRateCollector,
2020
};
2121

2222
pub struct GlobalStrategy<M: Math, A: MassMatrixAdaptStrategy<M>> {
@@ -457,7 +457,7 @@ mod test {
457457

458458
#[test]
459459
fn instanciate_adaptive_sampler() {
460-
use crate::mass_matrix_adapt::Strategy;
460+
use crate::mass_matrix::Strategy;
461461

462462
let ndim = 10;
463463
let func = NormalLogp::new(ndim, 3.);

src/lib.rs

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -101,30 +101,20 @@
101101
mod adapt_strategy;
102102
mod chain;
103103
mod cpu_math;
104-
mod csv_storage;
105104
mod euclidean_hamiltonian;
106105
mod hamiltonian;
107-
mod hashmap_storage;
108-
mod low_rank_mass_matrix;
109106
mod mass_matrix;
110-
mod mass_matrix_adapt;
111107
mod math;
112108
mod math_base;
113109
mod model;
114-
#[cfg(feature = "ndarray")]
115-
mod ndarray_storage;
116110
mod nuts;
117111
mod sampler;
118112
mod sampler_stats;
119113
mod state;
120-
mod stepsize_adam;
121-
mod stepsize_adapt;
122-
mod stepsize_dual_avg;
114+
mod stepsize;
123115
mod storage;
124116
mod transform_adapt_strategy;
125117
mod transformed_hamiltonian;
126-
#[cfg(feature = "zarr")]
127-
mod zarr_storage;
128118

129119
pub use nuts_derive::Storable;
130120
pub use nuts_storable::{HasDims, ItemType, Storable, Value};
@@ -143,16 +133,15 @@ pub use sampler::{
143133
};
144134
pub use sampler_stats::SamplerStats;
145135

146-
pub use low_rank_mass_matrix::LowRankSettings;
147-
pub use mass_matrix_adapt::DiagAdaptExpSettings;
148-
pub use stepsize_adam::AdamOptions;
149-
pub use stepsize_adapt::{StepSizeAdaptMethod, StepSizeAdaptOptions, StepSizeSettings};
136+
pub use mass_matrix::DiagAdaptExpSettings;
137+
pub use mass_matrix::LowRankSettings;
138+
pub use stepsize::{AdamOptions, StepSizeAdaptMethod, StepSizeAdaptOptions, StepSizeSettings};
150139
pub use transform_adapt_strategy::TransformedSettings;
151140

152141
#[cfg(feature = "zarr")]
153-
pub use zarr_storage::{ZarrAsyncConfig, ZarrAsyncTraceStorage, ZarrConfig, ZarrTraceStorage};
142+
pub use storage::{ZarrAsyncConfig, ZarrAsyncTraceStorage, ZarrConfig, ZarrTraceStorage};
154143

155-
pub use csv_storage::{CsvConfig, CsvTraceStorage};
156-
pub use hashmap_storage::{HashMapConfig, HashMapValue};
144+
pub use storage::{CsvConfig, CsvTraceStorage};
145+
pub use storage::{HashMapConfig, HashMapValue};
157146
#[cfg(feature = "ndarray")]
158-
pub use ndarray_storage::{NdarrayConfig, NdarrayTrace, NdarrayValue};
147+
pub use storage::{NdarrayConfig, NdarrayTrace, NdarrayValue};

src/mass_matrix_adapt.rs renamed to src/mass_matrix/adapt.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ use nuts_derive::Storable;
44
use rand::Rng;
55
use serde::Serialize;
66

7+
use super::mass_matrix::{DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance};
78
use crate::{
89
Math, NutsError,
910
euclidean_hamiltonian::EuclideanPoint,
1011
hamiltonian::Point,
11-
mass_matrix::{DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance},
1212
nuts::{Collector, NutsOptions},
1313
sampler_stats::SamplerStats,
1414
};

src/low_rank_mass_matrix.rs renamed to src/mass_matrix/low_rank.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,10 @@ use itertools::Itertools;
55
use nuts_derive::Storable;
66
use serde::Serialize;
77

8+
use super::adapt::MassMatrixAdaptStrategy;
9+
use super::mass_matrix::{DrawGradCollector, MassMatrix};
810
use crate::{
9-
Math, NutsError,
10-
euclidean_hamiltonian::EuclideanPoint,
11-
hamiltonian::Point,
12-
mass_matrix::{DrawGradCollector, MassMatrix},
13-
mass_matrix_adapt::MassMatrixAdaptStrategy,
11+
Math, NutsError, euclidean_hamiltonian::EuclideanPoint, hamiltonian::Point,
1412
sampler_stats::SamplerStats,
1513
};
1614

@@ -507,9 +505,7 @@ mod test {
507505
use rand::{Rng, SeedableRng, rngs::SmallRng};
508506
use rand_distr::StandardNormal;
509507

510-
use crate::low_rank_mass_matrix::mat_all_finite;
511-
512-
use super::{estimate_mass_matrix, spd_mean};
508+
use super::{estimate_mass_matrix, mat_all_finite, spd_mean};
513509

514510
#[test]
515511
fn test_spd_mean() {
File renamed without changes.

src/mass_matrix/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
mod adapt;
2+
mod low_rank;
3+
mod mass_matrix;
4+
5+
pub use adapt::DiagAdaptExpSettings;
6+
pub(crate) use adapt::MassMatrixAdaptStrategy;
7+
pub(crate) use adapt::Strategy;
8+
pub use low_rank::LowRankSettings;
9+
pub(crate) use low_rank::{LowRankMassMatrix, LowRankMassMatrixStrategy};
10+
pub(crate) use mass_matrix::{DiagMassMatrix, MassMatrix};

src/sampler.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ use crate::{
2424
adapt_strategy::{EuclideanAdaptOptions, GlobalStrategy, GlobalStrategyStatsOptions},
2525
chain::{AdaptStrategy, Chain, NutsChain, StatOptions},
2626
euclidean_hamiltonian::EuclideanHamiltonian,
27-
low_rank_mass_matrix::{LowRankMassMatrix, LowRankMassMatrixStrategy, LowRankSettings},
2827
mass_matrix::DiagMassMatrix,
29-
mass_matrix_adapt::Strategy as DiagMassMatrixStrategy,
28+
mass_matrix::Strategy as DiagMassMatrixStrategy,
29+
mass_matrix::{LowRankMassMatrix, LowRankMassMatrixStrategy, LowRankSettings},
3030
math_base::Math,
3131
model::Model,
3232
nuts::NutsOptions,
File renamed without changes.

src/stepsize_adapt.rs renamed to src/stepsize/adapt.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ use rand::Rng;
44
use rand_distr::Uniform;
55
use serde::Serialize;
66

7+
use super::adam::{Adam, AdamOptions};
8+
use super::dual_avg::{AcceptanceRateCollector, DualAverage, DualAverageOptions};
79
use crate::{
810
Math, NutsError,
911
hamiltonian::{Direction, Hamiltonian, LeapfrogResult, Point},
1012
nuts::{Collector, NutsOptions},
1113
sampler_stats::SamplerStats,
12-
stepsize_adam::{Adam, AdamOptions},
13-
stepsize_dual_avg::{AcceptanceRateCollector, DualAverage, DualAverageOptions},
1414
};
1515
use std::fmt::Debug;
1616

0 commit comments

Comments
 (0)