Skip to content

Commit 2fe6539

Browse files
committed
feat: bring back Storage.inspect
1 parent 68c08ba commit 2fe6539

File tree

8 files changed

+151
-30
lines changed

8 files changed

+151
-30
lines changed

src/sampler.rs

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -693,11 +693,13 @@ enum SamplerCommand {
693693
Continue,
694694
Progress,
695695
Flush,
696+
Inspect,
696697
}
697698

698-
enum SamplerResponse {
699+
enum SamplerResponse<T: Send + 'static> {
699700
Ok(),
700701
Progress(Box<[ChainProgress]>),
702+
Inspect(T),
701703
}
702704

703705
pub enum SamplerWaitResult<F: Send + 'static> {
@@ -709,7 +711,7 @@ pub enum SamplerWaitResult<F: Send + 'static> {
709711
pub struct Sampler<F: Send + 'static> {
710712
main_thread: JoinHandle<Result<(Option<anyhow::Error>, F)>>,
711713
commands: SyncSender<SamplerCommand>,
712-
responses: Receiver<SamplerResponse>,
714+
responses: Receiver<SamplerResponse<(Option<anyhow::Error>, F)>>,
713715
results: Receiver<Result<()>>,
714716
}
715717

@@ -827,7 +829,11 @@ impl<F: Send + 'static> Sampler<F> {
827829
pause_start = Instant::now();
828830
}
829831
is_paused = true;
830-
responses_tx.send(SamplerResponse::Ok())?;
832+
responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
833+
anyhow::anyhow!(
834+
"Could not send pause response to controller thread: {e}"
835+
)
836+
})?;
831837
}
832838
Ok(SamplerCommand::Continue) => {
833839
for chain in chains.iter() {
@@ -837,18 +843,50 @@ impl<F: Send + 'static> Sampler<F> {
837843
}
838844
pause_time += pause_start.elapsed();
839845
is_paused = false;
840-
responses_tx.send(SamplerResponse::Ok())?;
846+
responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
847+
anyhow::anyhow!(
848+
"Could not send continue response to controller thread: {e}"
849+
)
850+
})?;
841851
}
842852
Ok(SamplerCommand::Progress) => {
843853
let progress =
844854
chains.iter().map(|chain| chain.progress()).collect_vec();
845-
responses_tx.send(SamplerResponse::Progress(progress.into()))?;
855+
responses_tx.send(SamplerResponse::Progress(progress.into())).map_err(|e| {
856+
anyhow::anyhow!(
857+
"Could not send progress response to controller thread: {e}"
858+
)
859+
})?;
860+
}
861+
Ok(SamplerCommand::Inspect) => {
862+
let traces = chains
863+
.iter()
864+
.map(|chain| {
865+
chain
866+
.trace
867+
.lock()
868+
.expect("Poisoned lock")
869+
.as_ref()
870+
.map(|v| v.inspect())
871+
})
872+
.flatten()
873+
.collect_vec();
874+
let finalized_trace = trace.inspect(traces)?;
875+
responses_tx.send(SamplerResponse::Inspect(finalized_trace)).map_err(|e| {
876+
anyhow::anyhow!(
877+
"Could not send inspect response to controller thread: {e}"
878+
)
879+
})?;
846880
}
847881
Ok(SamplerCommand::Flush) => {
848882
for chain in chains.iter() {
849883
chain.flush()?;
850884
}
851-
responses_tx.send(SamplerResponse::Ok())?;
885+
responses_tx.send(SamplerResponse::Ok()).map_err(|e| {
886+
anyhow::anyhow!(
887+
"Could not send flush response to controller thread: {e}"
888+
)
889+
})?;
852890
}
853891
Err(RecvTimeoutError::Timeout) => {}
854892
Err(RecvTimeoutError::Disconnected) => {
@@ -919,6 +957,18 @@ impl<F: Send + 'static> Sampler<F> {
919957
Ok(())
920958
}
921959

960+
pub fn inspect(&mut self) -> Result<(Option<anyhow::Error>, F)> {
961+
self.commands.send(SamplerCommand::Inspect)?;
962+
let response = self
963+
.responses
964+
.recv()
965+
.context("Could not recieve inspect response from controller thread")?;
966+
let SamplerResponse::Inspect(trace) = response else {
967+
bail!("Got invalid response from sample controller thread");
968+
};
969+
Ok(trace)
970+
}
971+
922972
pub fn abort(self) -> Result<(Option<anyhow::Error>, F)> {
923973
drop(self.commands);
924974
let result = self.main_thread.join();

src/storage/core.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ pub trait ChainStorage: Send {
2323
/// Finalizes the storage and returns processed results.
2424
fn finalize(self) -> Result<Self::Finalized>;
2525

26+
fn inspect(&self) -> Result<Option<Self::Finalized>> {
27+
Ok(None)
28+
}
29+
2630
/// Flush any buffered data to ensure all samples are stored.
2731
fn flush(&self) -> Result<()>;
2832
}
@@ -63,4 +67,9 @@ pub trait TraceStorage: Send + Sync + Sized + 'static {
6367
self,
6468
traces: Vec<Result<<Self::ChainStorage as ChainStorage>::Finalized>>,
6569
) -> Result<(Option<anyhow::Error>, Self::Finalized)>;
70+
71+
fn inspect(
72+
&self,
73+
traces: Vec<Result<Option<<Self::ChainStorage as ChainStorage>::Finalized>>>,
74+
) -> Result<(Option<anyhow::Error>, Self::Finalized)>;
6675
}

src/storage/csv.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,12 @@ impl ChainStorage for CsvChainStorage {
348348
// In practice, the buffer will be flushed when the file is closed
349349
Ok(())
350350
}
351+
352+
fn inspect(&self) -> Result<Option<Self::Finalized>> {
353+
// For CSV storage, inspection does not produce a finalized result
354+
self.flush()?;
355+
Ok(None)
356+
}
351357
}
352358

353359
impl StorageConfig for CsvConfig {
@@ -599,6 +605,19 @@ impl TraceStorage for CsvTraceStorage {
599605
}
600606
Ok((None, ()))
601607
}
608+
609+
fn inspect(
610+
&self,
611+
traces: Vec<Result<Option<<Self::ChainStorage as ChainStorage>::Finalized>>>,
612+
) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
613+
// Check for any errors in the chain inspections
614+
for trace_result in traces {
615+
if let Err(err) = trace_result {
616+
return Ok((Some(err), ()));
617+
}
618+
}
619+
Ok((None, ()))
620+
}
602621
}
603622

604623
#[cfg(test)]

src/storage/hashmap.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,14 @@ impl HashMapValue {
5151
}
5252

5353
/// Main storage for HashMap MCMC traces
54+
#[derive(Clone)]
5455
pub struct HashMapTraceStorage {
5556
draw_types: Vec<(String, ItemType)>,
5657
param_types: Vec<(String, ItemType)>,
5758
}
5859

5960
/// Per-chain storage for HashMap MCMC traces
61+
#[derive(Clone)]
6062
pub struct HashMapChainStorage {
6163
warmup_stats: HashMap<String, HashMapValue>,
6264
sample_stats: HashMap<String, HashMapValue>,
@@ -251,6 +253,10 @@ impl ChainStorage for HashMapChainStorage {
251253
fn flush(&self) -> Result<()> {
252254
Ok(())
253255
}
256+
257+
fn inspect(&self) -> Result<Option<Self::Finalized>> {
258+
self.clone().finalize().map(Some)
259+
}
254260
}
255261

256262
pub struct HashMapConfig {}
@@ -314,4 +320,12 @@ impl TraceStorage for HashMapTraceStorage {
314320

315321
Ok((first_error, results))
316322
}
323+
324+
fn inspect(
325+
&self,
326+
traces: Vec<Result<Option<<Self::ChainStorage as ChainStorage>::Finalized>>>,
327+
) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
328+
self.clone()
329+
.finalize(traces.into_iter().map(|r| r.map(|o| o.unwrap())).collect())
330+
}
317331
}

src/storage/ndarray.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ struct SharedArrays {
136136
}
137137

138138
/// Main storage for ndarray MCMC traces
139+
#[derive(Clone)]
139140
pub struct NdarrayTraceStorage {
140141
shared_arrays: Arc<Mutex<SharedArrays>>,
141142
}
@@ -348,4 +349,20 @@ impl TraceStorage for NdarrayTraceStorage {
348349

349350
Ok((first_error, result))
350351
}
352+
353+
fn inspect(
354+
&self,
355+
traces: Vec<Result<Option<<Self::ChainStorage as ChainStorage>::Finalized>>>,
356+
) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
357+
self.clone().finalize(
358+
traces
359+
.into_iter()
360+
.map(|res| match res {
361+
Ok(Some(_)) => Ok(()),
362+
Ok(None) => Ok(()),
363+
Err(err) => Err(err),
364+
})
365+
.collect(),
366+
)
367+
}
351368
}

src/storage/zarr/async_impl.rs

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -143,14 +143,10 @@ async fn store_coords(
143143
_ => panic!("Unsupported coordinate type for {}", name),
144144
};
145145
let name: &String = name;
146-
let coord_array = ArrayBuilder::new(
147-
vec![len as u64],
148-
data_type,
149-
vec![len as u64].try_into().expect("Invalid chunk size"),
150-
fill_value,
151-
)
152-
.dimension_names(Some(vec![name.to_string()]))
153-
.build(store.clone(), &format!("{}/{}", group, name))?;
146+
let coord_array =
147+
ArrayBuilder::new(vec![len as u64], vec![len as u64], data_type, fill_value)
148+
.dimension_names(Some(vec![name.to_string()]))
149+
.build(store.clone(), &format!("{}/{}", group, name))?;
154150
let subset = vec![0];
155151
match coord {
156152
&Value::F64(ref v) => {
@@ -648,4 +644,16 @@ impl TraceStorage for ZarrAsyncTraceStorage {
648644
}
649645
Ok((None, ()))
650646
}
647+
648+
fn inspect(
649+
&self,
650+
traces: Vec<Result<Option<<Self::ChainStorage as ChainStorage>::Finalized>>>,
651+
) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
652+
for trace in traces {
653+
if let Err(err) = trace {
654+
return Ok((Some(err), ()));
655+
};
656+
}
657+
Ok((None, ()))
658+
}
651659
}

src/storage/zarr/common.rs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -209,14 +209,9 @@ pub fn create_arrays<TStorage: ?Sized>(
209209
.chain(std::iter::once(draw_chunk_size))
210210
.chain(extra_shape)
211211
.collect();
212-
let array = ArrayBuilder::new(
213-
shape,
214-
zarr_type,
215-
grid.try_into().expect("Invalid chunk sizes"),
216-
fill_value,
217-
)
218-
.dimension_names(Some(dims))
219-
.build(store.clone(), &format!("{}/{}", group_path, name))?;
212+
let array = ArrayBuilder::new(shape, grid, zarr_type, fill_value)
213+
.dimension_names(Some(dims))
214+
.build(store.clone(), &format!("{}/{}", group_path, name))?;
220215
arrays.insert(name.to_string(), array);
221216
}
222217
Ok(arrays)

src/storage/zarr/sync_impl.rs

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,11 @@ pub fn store_coords(
4040
_ => panic!("Unsupported coordinate type for {}", name),
4141
};
4242
let name: &String = name;
43-
let coord_array = ArrayBuilder::new(
44-
vec![len as u64],
45-
data_type,
46-
vec![len as u64].try_into().expect("Invalid chunk size"),
47-
fill_value,
48-
)
49-
.dimension_names(Some(vec![name.to_string()]))
50-
.build(store.clone(), &format!("{}/{}", group, name))?;
43+
44+
let coord_array =
45+
ArrayBuilder::new(vec![len as u64], vec![len as u64], data_type, fill_value)
46+
.dimension_names(Some(vec![name.to_string()]))
47+
.build(store.clone(), &format!("{}/{}", group, name))?;
5148
let subset = vec![0];
5249
match coord {
5350
&Value::F64(ref v) => coord_array.store_chunk_elements::<f64>(&subset, v)?,
@@ -534,4 +531,16 @@ impl TraceStorage for ZarrTraceStorage {
534531
}
535532
Ok((None, ()))
536533
}
534+
535+
fn inspect(
536+
&self,
537+
traces: Vec<Result<Option<<Self::ChainStorage as ChainStorage>::Finalized>>>,
538+
) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
539+
for trace in traces {
540+
if let Err(err) = trace {
541+
return Ok((Some(err), ()));
542+
};
543+
}
544+
Ok((None, ()))
545+
}
537546
}

0 commit comments

Comments
 (0)