Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/progress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ impl ProgressHandler {
let progress =
progress_to_value(progress_update_count, self.n_cores, time_sampling, progress);
let rendered = template.render_from(&self.engine, &progress).to_string();
let rendered = rendered.unwrap_or_else(|err| format!("{}", err));
let rendered = rendered.unwrap_or_else(|err| format!("{err}"));
let _ = Python::with_gil(|py| self.callback.call1(py, (rendered,)));
progress_update_count += 1;
};
Expand Down
15 changes: 7 additions & 8 deletions src/pyfunc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,8 @@ impl LogpError for PyLogpError {
let Ok(attr) = err.value(py).getattr("is_recoverable") else {
return false;
};
return attr
.is_truthy()
.expect("Could not access is_recoverable in error check");
attr.is_truthy()
.expect("Could not access is_recoverable in error check")
}),
Self::ReturnTypeError() => false,
Self::NotContiguousError(_) => false,
Expand All @@ -151,7 +150,7 @@ impl PyDensity {
transform_adapter: Option<&PyTransformAdapt>,
) -> Result<Self> {
let logp_func = Python::with_gil(|py| logp_clone_func.call0(py))?;
let transform_adapter = transform_adapter.map(|val| val.clone());
let transform_adapter = transform_adapter.cloned();
Ok(Self {
logp: logp_func,
transform_adapter,
Expand Down Expand Up @@ -185,7 +184,7 @@ impl CpuLogpFunc for PyDensity {
);
Ok(logp_val)
}
Err(err) => return Err(PyLogpError::PyError(err)),
Err(err) => Err(PyLogpError::PyError(err)),
}
})
}
Expand Down Expand Up @@ -359,7 +358,7 @@ impl TensorShape {
Self { shape, dims, size }
}
pub fn size(&self) -> usize {
return self.size;
self.size
}
}

Expand Down Expand Up @@ -617,14 +616,14 @@ impl Model for PyModel {
settings: &'model S,
) -> Result<Self::DrawStorage<'model, S>> {
let draws = settings.hint_num_tune() + settings.hint_num_draws();
Ok(PyTrace::new(
PyTrace::new(
rng,
chain_id,
self.variables.clone(),
&self.make_expand_func,
draws,
)
.context("Could not create PyTrace object")?)
.context("Could not create PyTrace object")
}

fn math(&self) -> Result<Self::Math<'_>> {
Expand Down
4 changes: 2 additions & 2 deletions src/pymc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ impl LogpError for ErrorCode {
}
}

impl<'a> CpuLogpFunc for &'a LogpFunc {
impl CpuLogpFunc for &LogpFunc {
type LogpError = ErrorCode;
type TransformParams = ();

Expand Down Expand Up @@ -175,7 +175,7 @@ impl<'model> DrawStorage for PyMcTrace<'model> {
let num_arrays = data.len() / size;
let data = Float64Array::from(data);
let item_field = Arc::new(Field::new("item", DataType::Float64, false));
let offsets = OffsetBuffer::from_lengths((0..num_arrays).into_iter().map(|_| size));
let offsets = OffsetBuffer::from_lengths((0..num_arrays).map(|_| size));
let array = LargeListArray::new(item_field.clone(), offsets, Arc::new(data), None);
let field = Field::new(name, DataType::LargeList(item_field), false);
(Arc::new(field), Arc::new(array) as Arc<dyn Array>)
Expand Down
Loading
Loading