Skip to content

Commit 83b917b

Browse files
committed
feat: Add low rank modified mass matrix adaptation
1 parent b1592db commit 83b917b

File tree

6 files changed

+340
-124
lines changed

6 files changed

+340
-124
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ notebooks/*.hpp
1111
perf.data*
1212
wheels
1313
.vscode/
14+
*~

Cargo.lock

Lines changed: 35 additions & 37 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ name = "_lib"
2222
crate-type = ["cdylib"]
2323

2424
[dependencies]
25-
nuts-rs = "0.11.0"
25+
nuts-rs = "0.12.0"
2626
numpy = "0.21.0"
2727
ndarray = "0.15.6"
2828
rand = "0.8.5"
@@ -33,7 +33,7 @@ rayon = "1.9.0"
3333
arrow = { version = "52.0.0", default-features = false, features = ["ffi"] }
3434
anyhow = "1.0.72"
3535
itertools = "0.13.0"
36-
bridgestan = "2.4.1"
36+
bridgestan = "2.5.0"
3737
rand_distr = "0.4.3"
3838
smallvec = "1.11.0"
3939
upon = { version = "0.8.1", default-features = false, features = [] }

python/nutpie/sample.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,7 @@ def sample(
460460
seed: Optional[int],
461461
save_warmup: bool,
462462
progress_bar: bool,
463+
low_rank_modified_mass_matrix: bool = False,
463464
init_mean: Optional[np.ndarray],
464465
return_raw_trace: bool,
465466
blocking: Literal[True],
@@ -478,6 +479,7 @@ def sample(
478479
seed: Optional[int],
479480
save_warmup: bool,
480481
progress_bar: bool,
482+
low_rank_modified_mass_matrix: bool = False,
481483
init_mean: Optional[np.ndarray],
482484
return_raw_trace: bool,
483485
blocking: Literal[False],
@@ -495,6 +497,7 @@ def sample(
495497
seed: Optional[int] = None,
496498
save_warmup: bool = True,
497499
progress_bar: bool = True,
500+
low_rank_modified_mass_matrix: bool = False,
498501
init_mean: Optional[np.ndarray] = None,
499502
return_raw_trace: bool = False,
500503
blocking: bool = True,
@@ -569,6 +572,9 @@ def sample(
569572
for the progress bar (eg CSS).
570573
progress_rate: int, default=500
571574
Rate in ms at which the progress should be updated.
575+
low_rank_modified_mass_matrix: bool, default=False
576+
Allow adaptation to some posterior correlations using
577+
a low-rank updated mass matrix.
572578
**kwargs
573579
Pass additional arguments to nutpie._lib.PySamplerArgs
574580
@@ -577,7 +583,11 @@ def sample(
577583
trace : arviz.InferenceData
578584
An ArviZ ``InferenceData`` object that contains the samples.
579585
"""
580-
settings = _lib.PyDiagGradNutsSettings(seed)
586+
587+
if low_rank_modified_mass_matrix:
588+
settings = _lib.PyNutsSettings.LowRank(seed)
589+
else:
590+
settings = _lib.PyNutsSettings.Diag(seed)
581591
settings.num_tune = tune
582592
settings.num_draws = draws
583593
settings.num_chains = chains

src/pyfunc.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ impl PyVariable {
4141
let field = Arc::new(Field::new("item", DataType::Boolean, false));
4242
DataType::FixedSizeList(field, tensor_type.size() as i32)
4343
}
44-
ExpandDtype::ArrayFloat64 { tensor_type } => {
44+
ExpandDtype::ArrayFloat64 { tensor_type: _ } => {
4545
let field = Arc::new(Field::new("item", DataType::Float64, true));
4646
DataType::List(field)
4747
}
@@ -303,11 +303,11 @@ impl ExpandDtype {
303303
#[getter]
304304
fn shape(&self) -> Option<Vec<usize>> {
305305
match self {
306-
Self::BooleanArray {tensor_type} => { Some(tensor_type.shape.iter().cloned().collect()) },
307-
Self::ArrayFloat64 {tensor_type} => { Some(tensor_type.shape.iter().cloned().collect()) },
308-
Self::ArrayFloat32 {tensor_type} => { Some(tensor_type.shape.iter().cloned().collect()) },
309-
Self::ArrayInt64 {tensor_type} => { Some(tensor_type.shape.iter().cloned().collect()) },
310-
_ => { None },
306+
Self::BooleanArray { tensor_type } => Some(tensor_type.shape.iter().cloned().collect()),
307+
Self::ArrayFloat64 { tensor_type } => Some(tensor_type.shape.iter().cloned().collect()),
308+
Self::ArrayFloat32 { tensor_type } => Some(tensor_type.shape.iter().cloned().collect()),
309+
Self::ArrayInt64 { tensor_type } => Some(tensor_type.shape.iter().cloned().collect()),
310+
_ => None,
311311
}
312312
}
313313
}

0 commit comments

Comments
 (0)