Skip to content

Commit 8d23cf0

Browse files
authored
Fix/error handling (#25)
- Fix error handling with proper `thiserror` use
- Version bump for patch release
1 parent be43b4d commit 8d23cf0

File tree

5 files changed

+89
-175
lines changed

5 files changed

+89
-175
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ license = "AGPL-3.0"
99
name = "koho"
1010
readme = "README.md"
1111
repository = "https://github.com/TheMesocarp/koho"
12-
version = "0.1.0"
12+
version = "0.1.1"
1313

1414

1515
[dependencies]

src/error.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ pub enum KohoError {
2222
NoRestrictionDefined,
2323
#[error("Sections for layer {i} are not yet provided")]
2424
NoSections { i: usize },
25-
#[error("Candle module error")]
26-
Candle(Error),
27-
#[error("Misc")]
28-
Msg(String),
25+
#[error("Candle module error {0}")]
26+
Candle(#[from] Error),
2927
}

src/lib.rs

Lines changed: 24 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,7 @@ impl SheafNN {
110110
) -> Result<TrainingMetrics, KohoError> {
111111
// Create the optimizer
112112
let mut optimizer =
113-
create_optimizer(optimizer_kind, self.parameters_mut(), lr, optimizer_params)
114-
.map_err(KohoError::Candle)?;
113+
create_optimizer(optimizer_kind, self.parameters_mut(), lr, optimizer_params)?;
115114

116115
let mut metrics = TrainingMetrics::new(epochs);
117116

@@ -123,21 +122,16 @@ impl SheafNN {
123122
let output = self.forward(input.clone(), down_included)?;
124123

125124
// Compute loss
126-
let loss_tensor = self
127-
.loss_fn
128-
.compute(output.inner(), target.inner())
129-
.map_err(KohoError::Candle)?;
125+
let loss_tensor = self.loss_fn.compute(output.inner(), target.inner())?;
130126

131127
let loss_val = loss_tensor.to_scalar::<f32>().unwrap_or(f32::NAN);
132128
total_loss += loss_val;
133129

134130
// Backward pass
135-
let grads = loss_tensor.backward().map_err(KohoError::Candle)?;
131+
let grads = loss_tensor.backward()?;
136132

137133
// Optimizer step (in-place update of parameters)
138-
optimizer
139-
.step(&grads, self.parameters_mut())
140-
.map_err(KohoError::Candle)?;
134+
optimizer.step(&grads, self.parameters_mut())?;
141135
}
142136

143137
let avg_loss = total_loss / (data.len() as f32);
@@ -161,8 +155,8 @@ impl SheafNN {
161155
let params = self.parameters();
162156
println!("Total parameters: {}", params.len());
163157
for (i, param) in params.iter().enumerate() {
164-
let param_data = param.as_tensor().flatten_all().map_err(KohoError::Candle)?;
165-
let param_vec = param_data.to_vec1::<f32>().map_err(KohoError::Candle)?;
158+
let param_data = param.as_tensor().flatten_all()?;
159+
let param_vec = param_data.to_vec1::<f32>()?;
166160
println!(
167161
"Parameter {i}: shape={:?}, first_few_values={:?}",
168162
param.shape(),
@@ -172,8 +166,7 @@ impl SheafNN {
172166

173167
// Create the optimizer
174168
let mut optimizer =
175-
create_optimizer(optimizer_kind, self.parameters_mut(), lr, optimizer_params)
176-
.map_err(KohoError::Candle)?;
169+
create_optimizer(optimizer_kind, self.parameters_mut(), lr, optimizer_params)?;
177170

178171
let mut metrics = TrainingMetrics::new(epochs);
179172

@@ -184,25 +177,22 @@ impl SheafNN {
184177
println!("\nEpoch {epoch}, Batch {batch_idx}");
185178

186179
// Print input/target info
187-
let input_data = input.inner().flatten_all().map_err(KohoError::Candle)?;
188-
let target_data = target.inner().flatten_all().map_err(KohoError::Candle)?;
189-
let input_vec = input_data.to_vec1::<f32>().map_err(KohoError::Candle)?;
190-
let target_vec = target_data.to_vec1::<f32>().map_err(KohoError::Candle)?;
180+
let input_data = input.inner().flatten_all()?;
181+
let target_data = target.inner().flatten_all()?;
182+
let input_vec = input_data.to_vec1::<f32>()?;
183+
let target_vec = target_data.to_vec1::<f32>()?;
191184

192185
println!("Input: {input_vec:?}");
193186
println!("Target: {target_vec:?}");
194187

195188
// Forward pass
196189
let output = self.forward(input.clone(), down_included)?;
197-
let output_data = output.inner().flatten_all().map_err(KohoError::Candle)?;
198-
let output_vec = output_data.to_vec1::<f32>().map_err(KohoError::Candle)?;
190+
let output_data = output.inner().flatten_all()?;
191+
let output_vec = output_data.to_vec1::<f32>()?;
199192
println!("Output: {output_vec:?}");
200193

201194
// Compute loss
202-
let loss_tensor = self
203-
.loss_fn
204-
.compute(output.inner(), target.inner())
205-
.map_err(KohoError::Candle)?;
195+
let loss_tensor = self.loss_fn.compute(output.inner(), target.inner())?;
206196

207197
let loss_val = loss_tensor.to_scalar::<f32>().unwrap_or(f32::NAN);
208198
total_loss += loss_val;
@@ -214,15 +204,15 @@ impl SheafNN {
214204

215205
// Backward pass
216206
println!("Computing gradients...");
217-
let grads = loss_tensor.backward().map_err(KohoError::Candle)?;
207+
let grads = loss_tensor.backward()?;
218208

219209
// Check gradients
220210
let params_mut = self.parameters_mut();
221211
println!("Checking gradients for {} parameters:", params_mut.len());
222212
for (i, param) in params_mut.iter().enumerate() {
223213
if let Some(grad) = grads.get(param) {
224-
let grad_data = grad.flatten_all().map_err(KohoError::Candle)?;
225-
let grad_vec = grad_data.to_vec1::<f32>().map_err(KohoError::Candle)?;
214+
let grad_data = grad.flatten_all()?;
215+
let grad_vec = grad_data.to_vec1::<f32>()?;
226216
let grad_norm = grad_vec.iter().map(|x| x * x).sum::<f32>().sqrt();
227217
println!(
228218
" Param {i}: grad_norm={grad_norm}, first_few_grads={:?}",
@@ -247,9 +237,7 @@ impl SheafNN {
247237
})
248238
.collect();
249239

250-
optimizer
251-
.step(&grads, self.parameters_mut())
252-
.map_err(KohoError::Candle)?;
240+
optimizer.step(&grads, self.parameters_mut())?;
253241

254242
let params_after: Vec<_> = self
255243
.parameters_mut()
@@ -382,8 +370,7 @@ mod integration_tests {
382370
let input = sheaf.get_k_cochain(0)?;
383371

384372
let target_data = vec![0.8f32, 0.6f32, 0.4f32];
385-
let target = Matrix::from_slice(&target_data, 1, 3, Device::Cpu, DType::F32)
386-
.map_err(KohoError::Candle)?;
373+
let target = Matrix::from_slice(&target_data, 1, 3, Device::Cpu, DType::F32)?;
387374

388375
let training_data = vec![(input, target)];
389376
let mut network = SheafNN::init(0, false, LossKind::MSE, sheaf);
@@ -408,13 +395,8 @@ mod integration_tests {
408395
let output = network.forward(initial_input, false)?;
409396

410397
// The output should be different from input (diffusion occurred)
411-
let input_vals = network
412-
.sheaf
413-
.get_k_cochain(0)?
414-
.inner()
415-
.to_vec2::<f32>()
416-
.map_err(KohoError::Candle)?;
417-
let output_vals = output.inner().to_vec2::<f32>().map_err(KohoError::Candle)?;
398+
let input_vals = network.sheaf.get_k_cochain(0)?.inner().to_vec2::<f32>()?;
399+
let output_vals = output.inner().to_vec2::<f32>()?;
418400

419401
println!("Input: {input_vals:?}");
420402
println!("Output: {output_vals:?}");
@@ -439,8 +421,7 @@ mod integration_tests {
439421
println!("got edges");
440422

441423
let target_data = vec![0.5f32, 0.3f32, 0.7f32];
442-
let target = Matrix::from_slice(&target_data, 1, 3, Device::Cpu, DType::F32)
443-
.map_err(KohoError::Candle)?;
424+
let target = Matrix::from_slice(&target_data, 1, 3, Device::Cpu, DType::F32)?;
444425

445426
let training_data = vec![(input, target)];
446427

@@ -473,8 +454,7 @@ mod integration_tests {
473454
let input = sheaf_learned.get_k_cochain(0)?;
474455

475456
let target_data = vec![0.8f32, 0.6f32, 0.4f32];
476-
let target = Matrix::from_slice(&target_data, 1, 3, Device::Cpu, DType::F32)
477-
.map_err(KohoError::Candle)?;
457+
let target = Matrix::from_slice(&target_data, 1, 3, Device::Cpu, DType::F32)?;
478458
let training_data = vec![(input.clone(), target.clone())];
479459

480460
// train learned network
@@ -523,8 +503,7 @@ mod integration_tests {
523503
let input = sheaf.get_k_cochain(0)?;
524504

525505
let target_data = vec![0.9f32, 0.8f32, 0.7f32];
526-
let target = Matrix::from_slice(&target_data, 1, 3, Device::Cpu, DType::F32)
527-
.map_err(KohoError::Candle)?;
506+
let target = Matrix::from_slice(&target_data, 1, 3, Device::Cpu, DType::F32)?;
528507
let training_data = vec![(input, target)];
529508

530509
let mut network = SheafNN::init(0, false, LossKind::MSE, sheaf);

0 commit comments

Comments (0)