Skip to content

Commit 1db17a0

Browse files
committed
feat: Improve error info for BadInitGrad
1 parent 19f653a commit 1db17a0

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

src/euclidean_hamiltonian.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,9 @@ impl<M: Math, Mass: MassMatrix<M>> Hamiltonian<M> for EuclideanHamiltonian<M, Ma
367367
.update_potential_gradient(math)
368368
.map_err(|e| NutsError::LogpFailure(Box::new(e)))?;
369369
if !math.array_all_finite_and_nonzero(&point.gradient) {
370-
Err(NutsError::BadInitGrad())
370+
Err(NutsError::BadInitGrad(
371+
anyhow::anyhow!("Invalid initial point").into(),
372+
))
371373
} else {
372374
Ok(state)
373375
}

src/nuts.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@ use crate::math_base::Math;
2121
#[non_exhaustive]
2222
#[derive(Error, Debug)]
2323
pub enum NutsError {
24-
#[error("Logp function returned error: {0}")]
24+
#[error("Logp function returned error: {0:?}")]
2525
LogpFailure(Box<dyn std::error::Error + Send + Sync>),
2626

2727
#[error("Could not serialize sample stats")]
2828
SerializeFailure(),
2929

30-
#[error("Could not initialize state because of bad initial gradient.")]
31-
BadInitGrad(),
30+
#[error("Could not initialize state because of bad initial gradient: {0:?}")]
31+
BadInitGrad(Box<dyn std::error::Error + Send + Sync>),
3232
}
3333

3434
pub type Result<T> = std::result::Result<T, NutsError>;

src/transformed_hamiltonian.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,10 +219,10 @@ impl<M: Math> TransformedHamiltonian<M> {
219219
math.read_from_slice(&mut position_array, position);
220220
let _ = math
221221
.logp_array(&position_array, &mut gradient_array)
222-
.map_err(|_| NutsError::BadInitGrad())?;
222+
.map_err(|e| NutsError::BadInitGrad(Box::new(e)))?;
223223
let params = math
224224
.new_transformation(rng, &position_array, &gradient_array, chain)
225-
.map_err(|_| NutsError::BadInitGrad())?;
225+
.map_err(|e| NutsError::BadInitGrad(Box::new(e)))?;
226226
self.params = Some(params);
227227
Ok(())
228228
}
@@ -240,7 +240,7 @@ impl<M: Math> TransformedHamiltonian<M> {
240240
grads,
241241
self.params.as_mut().expect("Transformation was empty"),
242242
)
243-
.map_err(|_| NutsError::BadInitGrad())?;
243+
.map_err(|e| NutsError::BadInitGrad(Box::new(e)))?;
244244
Ok(())
245245
}
246246
}
@@ -407,7 +407,9 @@ impl<M: Math> Hamiltonian<M> for TransformedHamiltonian<M> {
407407
.map_err(|e| NutsError::LogpFailure(Box::new(e)))?;
408408

409409
if !point.is_valid(math) {
410-
Err(NutsError::BadInitGrad())
410+
Err(NutsError::BadInitGrad(
411+
anyhow::anyhow!("Invalid initial point").into(),
412+
))
411413
} else {
412414
Ok(state)
413415
}

0 commit comments

Comments
 (0)