Skip to content

Commit d859b82

Browse files
committed
refactor: clean up some left overs from new storage
1 parent 4e9a7f8 commit d859b82

File tree

14 files changed

+63
-152
lines changed

14 files changed

+63
-152
lines changed

nuts-derive/src/lib.rs

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -79,34 +79,34 @@ enum StorableField {
7979

8080
// Check if a type is a generic type parameter
8181
fn is_generic_param(ty: &Type, generics: &syn::Generics) -> bool {
82-
if let Type::Path(type_path) = ty {
83-
if type_path.path.segments.len() == 1 {
84-
let type_name = &type_path.path.segments.first().unwrap().ident;
85-
return generics.params.iter().any(|param| {
86-
if let GenericParam::Type(type_param) = param {
87-
&type_param.ident == type_name
88-
} else {
89-
false
90-
}
91-
});
92-
}
82+
if let Type::Path(type_path) = ty
83+
&& type_path.path.segments.len() == 1
84+
{
85+
let type_name = &type_path.path.segments.first().unwrap().ident;
86+
return generics.params.iter().any(|param| {
87+
if let GenericParam::Type(type_param) = param {
88+
&type_param.ident == type_name
89+
} else {
90+
false
91+
}
92+
});
9393
}
9494
false
9595
}
9696

9797
// Check if a type implements Storable trait based on bounds
9898
fn has_storable_bound(ty: &Ident, generics: &syn::Generics) -> bool {
9999
for param in &generics.params {
100-
if let GenericParam::Type(type_param) = param {
101-
if &type_param.ident == ty {
102-
for bound in &type_param.bounds {
103-
if let syn::TypeParamBound::Trait(trait_bound) = bound {
104-
let path = &trait_bound.path;
105-
if path.segments.len() == 1
106-
&& path.segments.first().unwrap().ident == "Storable"
107-
{
108-
return true;
109-
}
100+
if let GenericParam::Type(type_param) = param
101+
&& &type_param.ident == ty
102+
{
103+
for bound in &type_param.bounds {
104+
if let syn::TypeParamBound::Trait(trait_bound) = bound {
105+
let path = &trait_bound.path;
106+
if path.segments.len() == 1
107+
&& path.segments.first().unwrap().ident == "Storable"
108+
{
109+
return true;
110110
}
111111
}
112112
}
@@ -165,7 +165,7 @@ pub fn storable_derive(input: TokenStream) -> TokenStream {
165165
ty_str
166166
);
167167
};
168-
let item = if path.segments.first().unwrap().ident.to_string() == "Option" {
168+
let item = if path.segments.first().unwrap().ident == "Option" {
169169
if let PathArguments::AngleBracketed(AngleBracketedGenericArguments {
170170
args, ..
171171
}) = &path.segments.first().unwrap().arguments
@@ -539,7 +539,7 @@ pub fn storable_derive(input: TokenStream) -> TokenStream {
539539
});
540540

541541
let get_all_fn = quote! {
542-
fn get_all(&self, parent: &P) -> Vec<(&str, Option<nuts_storable::Value>)> {
542+
fn get_all<'a>(&'a mut self, parent: &'a P) -> Vec<(&'a str, Option<nuts_storable::Value>)> {
543543
let mut result = Vec::with_capacity(Self::names(parent).len());
544544
#(#get_all_exprs)*
545545
result

nuts-derive/tests/storable.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ fn test_storable() {
6161
value2: 8.0,
6262
draws2: vec![9.0, 2.0, 3.0],
6363
};
64-
let stats = ExampleStats {
64+
let mut stats = ExampleStats {
6565
step_size: 0.1,
6666
n_steps: 10,
6767
is_adapting: true,
@@ -72,7 +72,7 @@ fn test_storable() {
7272
_not_stored: "should not be stored".to_string(),
7373
};
7474

75-
let stats2: Example2<Parent, _> = Example2 {
75+
let mut stats2: Example2<Parent, _> = Example2 {
7676
field1: 42,
7777
field2: stats.clone(),
7878
_phantom: std::marker::PhantomData,

nuts-storable/src/lib.rs

Lines changed: 3 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -89,107 +89,7 @@ pub trait Storable<P: HasDims + ?Sized>: Send + Sync {
8989
fn item_type(parent: &P, item: &str) -> ItemType;
9090
fn dims<'a>(parent: &'a P, item: &str) -> Vec<&'a str>;
9191

92-
fn get_all(&self, parent: &P) -> Vec<(&str, Option<Value>)>;
93-
94-
fn get_f64(&self, parent: &P, name: &str) -> Option<f64> {
95-
self.get_all(parent)
96-
.into_iter()
97-
.find(|(item_name, _)| *item_name == name)
98-
.and_then(|(_, value)| match value {
99-
Some(Value::ScalarF64(v)) => Some(v),
100-
_ => None,
101-
})
102-
}
103-
104-
fn get_f32(&self, parent: &P, name: &str) -> Option<f32> {
105-
self.get_all(parent)
106-
.into_iter()
107-
.find(|(item_name, _)| *item_name == name)
108-
.and_then(|(_, value)| match value {
109-
Some(Value::ScalarF32(v)) => Some(v),
110-
_ => None,
111-
})
112-
}
113-
114-
fn get_u64(&self, parent: &P, name: &str) -> Option<u64> {
115-
self.get_all(parent)
116-
.into_iter()
117-
.find(|(item_name, _)| *item_name == name)
118-
.and_then(|(_, value)| match value {
119-
Some(Value::ScalarU64(v)) => Some(v),
120-
_ => None,
121-
})
122-
}
123-
124-
fn get_i64(&self, parent: &P, name: &str) -> Option<i64> {
125-
self.get_all(parent)
126-
.into_iter()
127-
.find(|(item_name, _)| *item_name == name)
128-
.and_then(|(_, value)| match value {
129-
Some(Value::ScalarI64(v)) => Some(v),
130-
_ => None,
131-
})
132-
}
133-
134-
fn get_bool(&self, parent: &P, name: &str) -> Option<bool> {
135-
self.get_all(parent)
136-
.into_iter()
137-
.find(|(item_name, _)| *item_name == name)
138-
.and_then(|(_, value)| match value {
139-
Some(Value::ScalarBool(v)) => Some(v),
140-
_ => None,
141-
})
142-
}
143-
144-
fn get_vec_f64(&self, parent: &P, name: &str) -> Option<Vec<f64>> {
145-
self.get_all(parent)
146-
.into_iter()
147-
.find(|(item_name, _)| *item_name == name)
148-
.and_then(|(_, value)| match value {
149-
Some(Value::F64(v)) => Some(v),
150-
_ => None,
151-
})
152-
}
153-
154-
fn get_vec_f32(&self, parent: &P, name: &str) -> Option<Vec<f32>> {
155-
self.get_all(parent)
156-
.into_iter()
157-
.find(|(item_name, _)| *item_name == name)
158-
.and_then(|(_, value)| match value {
159-
Some(Value::F32(v)) => Some(v),
160-
_ => None,
161-
})
162-
}
163-
164-
fn get_vec_u64(&self, parent: &P, name: &str) -> Option<Vec<u64>> {
165-
self.get_all(parent)
166-
.into_iter()
167-
.find(|(item_name, _)| *item_name == name)
168-
.and_then(|(_, value)| match value {
169-
Some(Value::U64(v)) => Some(v),
170-
_ => None,
171-
})
172-
}
173-
174-
fn get_vec_i64(&self, parent: &P, name: &str) -> Option<Vec<i64>> {
175-
self.get_all(parent)
176-
.into_iter()
177-
.find(|(item_name, _)| *item_name == name)
178-
.and_then(|(_, value)| match value {
179-
Some(Value::I64(v)) => Some(v),
180-
_ => None,
181-
})
182-
}
183-
184-
fn get_vec_bool(&self, parent: &P, name: &str) -> Option<Vec<bool>> {
185-
self.get_all(parent)
186-
.into_iter()
187-
.find(|(item_name, _)| *item_name == name)
188-
.and_then(|(_, value)| match value {
189-
Some(Value::Bool(v)) => Some(v),
190-
_ => None,
191-
})
192-
}
92+
fn get_all<'a>(&'a mut self, parent: &'a P) -> Vec<(&'a str, Option<Value>)>;
19393
}
19494

19595
impl<P: HasDims> Storable<P> for Vec<f64> {
@@ -205,7 +105,7 @@ impl<P: HasDims> Storable<P> for Vec<f64> {
205105
vec!["dim"]
206106
}
207107

208-
fn get_all(&self, _parent: &P) -> Vec<(&str, Option<Value>)> {
108+
fn get_all<'a>(&'a mut self, _parent: &'a P) -> Vec<(&'a str, Option<Value>)> {
209109
vec![("value", Some(Value::F64(self.clone())))]
210110
}
211111
}
@@ -223,7 +123,7 @@ impl<P: HasDims> Storable<P> for () {
223123
panic!("No items in unit type")
224124
}
225125

226-
fn get_all(&self, _parent: &P) -> Vec<(&str, Option<Value>)> {
126+
fn get_all(&mut self, _parent: &P) -> Vec<(&str, Option<Value>)> {
227127
vec![]
228128
}
229129
}

src/adapt_strategy.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ use nuts_storable::{HasDims, Storable};
55
use rand::Rng;
66
use serde::Serialize;
77

8-
use super::mass_matrix::MassMatrixAdaptStrategy;
98
use super::stepsize::AcceptanceRateCollector;
109
use super::stepsize::{StepSizeSettings, Strategy as StepSizeStrategy};
10+
use crate::mass_matrix::MassMatrixAdaptStrategy;
1111
use crate::{
1212
NutsError,
1313
chain::AdaptStrategy,

src/cpu_math.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ impl<F: CpuLogpFunc> CpuMath<F> {
2828
pub enum CpuMathError {
2929
#[error("Error during array operation")]
3030
ArrayError(),
31-
#[error("Error during point expansion")]
32-
ExpandError(),
31+
#[error("Error during point expansion: {0}")]
32+
ExpandError(String),
3333
}
3434

3535
impl<F: CpuLogpFunc> HasDims for CpuMath<F> {
@@ -57,7 +57,10 @@ impl<F: CpuLogpFunc> Storable<CpuMath<F>> for ExpandedVectorWrapper<F> {
5757
F::ExpandedVector::dims(&parent.logp_func, item)
5858
}
5959

60-
fn get_all(&self, parent: &CpuMath<F>) -> Vec<(&str, Option<nuts_storable::Value>)> {
60+
fn get_all<'a>(
61+
&'a mut self,
62+
parent: &'a CpuMath<F>,
63+
) -> Vec<(&'a str, Option<nuts_storable::Value>)> {
6164
self.0.get_all(&parent.logp_func)
6265
}
6366
}
@@ -138,7 +141,9 @@ impl<F: CpuLogpFunc> Math for CpuMath<F> {
138141
rng,
139142
array
140143
.try_as_col_major()
141-
.ok_or(CpuMathError::ExpandError())?
144+
.ok_or_else(|| {
145+
CpuMathError::ExpandError("Internal vector was not col major".into())
146+
})?
142147
.as_slice(),
143148
)?,
144149
))

src/mass_matrix/adapt.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use nuts_derive::Storable;
44
use rand::Rng;
55
use serde::Serialize;
66

7-
use super::mass_matrix::{DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance};
7+
use super::diagonal::{DiagMassMatrix, DrawGradCollector, MassMatrix, RunningVariance};
88
use crate::{
99
Math, NutsError,
1010
euclidean_hamiltonian::EuclideanPoint,
File renamed without changes.

src/mass_matrix/low_rank.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use nuts_derive::Storable;
66
use serde::Serialize;
77

88
use super::adapt::MassMatrixAdaptStrategy;
9-
use super::mass_matrix::{DrawGradCollector, MassMatrix};
9+
use super::diagonal::{DrawGradCollector, MassMatrix};
1010
use crate::{
1111
Math, NutsError, euclidean_hamiltonian::EuclideanPoint, hamiltonian::Point,
1212
sampler_stats::SamplerStats,

src/mass_matrix/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
mod adapt;
2+
mod diagonal;
23
mod low_rank;
3-
mod mass_matrix;
44

55
pub use adapt::DiagAdaptExpSettings;
66
pub(crate) use adapt::MassMatrixAdaptStrategy;
77
pub(crate) use adapt::Strategy;
8+
pub(crate) use diagonal::{DiagMassMatrix, MassMatrix};
89
pub use low_rank::LowRankSettings;
910
pub(crate) use low_rank::{LowRankMassMatrix, LowRankMassMatrixStrategy};
10-
pub(crate) use mass_matrix::{DiagMassMatrix, MassMatrix};

src/sampler.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ impl<T: TraceStorage> ChainProcess<T> {
626626

627627
let now = Instant::now();
628628
//let (point, info) = sampler.draw().unwrap();
629-
let (_point, draw_data, stats, info) = sampler.expanded_draw().unwrap();
629+
let (_point, mut draw_data, mut stats, info) = sampler.expanded_draw().unwrap();
630630

631631
let mut guard = chain_trace
632632
.lock()

0 commit comments

Comments
 (0)