Skip to content

Commit 739564d

Browse files
authored
fix: gradient tracking
Added many more integration tests and fixed the Candle backend's gradient tracking.
2 parents 0b129c2 + 0adca59 commit 739564d

File tree

6 files changed

+1202
-88
lines changed

6 files changed

+1202
-88
lines changed

src/lib.rs

Lines changed: 377 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,152 @@ impl SheafNN {
145145
}
146146
Ok(metrics)
147147
}
148+
149+
pub fn train_debug(
150+
&mut self,
151+
data: &[(Matrix, Matrix)],
152+
epochs: usize,
153+
down_included: bool,
154+
optimizer_kind: OptimKind,
155+
lr: f64,
156+
optimizer_params: OptimizerParams,
157+
) -> Result<TrainingMetrics, KohoError> {
158+
println!("=== Training Debug Info ===");
159+
160+
// Check if we have any parameters to optimize
161+
let params = self.parameters();
162+
println!("Total parameters: {}", params.len());
163+
for (i, param) in params.iter().enumerate() {
164+
let param_data = param.as_tensor().flatten_all().map_err(KohoError::Candle)?;
165+
let param_vec = param_data.to_vec1::<f32>().map_err(KohoError::Candle)?;
166+
println!(
167+
"Parameter {i}: shape={:?}, first_few_values={:?}",
168+
param.shape(),
169+
&param_vec[..param_vec.len().min(5)]
170+
);
171+
}
172+
173+
// Create the optimizer
174+
let mut optimizer =
175+
create_optimizer(optimizer_kind, self.parameters_mut(), lr, optimizer_params)
176+
.map_err(KohoError::Candle)?;
177+
178+
let mut metrics = TrainingMetrics::new(epochs);
179+
180+
for epoch in 1..=epochs {
181+
let mut total_loss = 0.0_f32;
182+
183+
for (batch_idx, (input, target)) in data.iter().enumerate() {
184+
println!("\nEpoch {epoch}, Batch {batch_idx}");
185+
186+
// Print input/target info
187+
let input_data = input.inner().flatten_all().map_err(KohoError::Candle)?;
188+
let target_data = target.inner().flatten_all().map_err(KohoError::Candle)?;
189+
let input_vec = input_data.to_vec1::<f32>().map_err(KohoError::Candle)?;
190+
let target_vec = target_data.to_vec1::<f32>().map_err(KohoError::Candle)?;
191+
192+
println!("Input: {input_vec:?}");
193+
println!("Target: {target_vec:?}");
194+
195+
// Forward pass
196+
let output = self.forward(input.clone(), down_included)?;
197+
let output_data = output.inner().flatten_all().map_err(KohoError::Candle)?;
198+
let output_vec = output_data.to_vec1::<f32>().map_err(KohoError::Candle)?;
199+
println!("Output: {output_vec:?}");
200+
201+
// Compute loss
202+
let loss_tensor = self
203+
.loss_fn
204+
.compute(output.inner(), target.inner())
205+
.map_err(KohoError::Candle)?;
206+
207+
let loss_val = loss_tensor.to_scalar::<f32>().unwrap_or(f32::NAN);
208+
total_loss += loss_val;
209+
println!("Loss: {loss_val}");
210+
211+
// Check if loss tensor requires grad
212+
println!("Loss tensor shape: {:?}", loss_tensor.shape());
213+
println!("Loss tensor dtype: {:?}", loss_tensor.dtype());
214+
215+
// Backward pass
216+
println!("Computing gradients...");
217+
let grads = loss_tensor.backward().map_err(KohoError::Candle)?;
218+
219+
// Check gradients
220+
let params_mut = self.parameters_mut();
221+
println!("Checking gradients for {} parameters:", params_mut.len());
222+
for (i, param) in params_mut.iter().enumerate() {
223+
if let Some(grad) = grads.get(param) {
224+
let grad_data = grad.flatten_all().map_err(KohoError::Candle)?;
225+
let grad_vec = grad_data.to_vec1::<f32>().map_err(KohoError::Candle)?;
226+
let grad_norm = grad_vec.iter().map(|x| x * x).sum::<f32>().sqrt();
227+
println!(
228+
" Param {i}: grad_norm={grad_norm}, first_few_grads={:?}",
229+
&grad_vec[..grad_vec.len().min(3)]
230+
);
231+
} else {
232+
println!(" Param {i}: NO GRADIENT FOUND");
233+
}
234+
}
235+
236+
// Optimizer step (in-place update of parameters)
237+
println!("Applying optimizer step...");
238+
let params_before: Vec<_> = self
239+
.parameters_mut()
240+
.iter()
241+
.map(|p| {
242+
p.as_tensor()
243+
.flatten_all()
244+
.unwrap()
245+
.to_vec1::<f32>()
246+
.unwrap()
247+
})
248+
.collect();
249+
250+
optimizer
251+
.step(&grads, self.parameters_mut())
252+
.map_err(KohoError::Candle)?;
253+
254+
let params_after: Vec<_> = self
255+
.parameters_mut()
256+
.iter()
257+
.map(|p| {
258+
p.as_tensor()
259+
.flatten_all()
260+
.unwrap()
261+
.to_vec1::<f32>()
262+
.unwrap()
263+
})
264+
.collect();
265+
266+
// Check if parameters actually changed
267+
for (i, (before, after)) in
268+
params_before.iter().zip(params_after.iter()).enumerate()
269+
{
270+
let diff_norm: f32 = before
271+
.iter()
272+
.zip(after.iter())
273+
.map(|(b, a)| (b - a).powi(2))
274+
.sum::<f32>()
275+
.sqrt();
276+
println!(" Param {i} change norm: {diff_norm}");
277+
}
278+
279+
if epoch <= 3 {
280+
// Only print detailed info for first few epochs
281+
println!("--- End batch {batch_idx} ---");
282+
}
283+
}
284+
285+
let avg_loss = total_loss / (data.len() as f32);
286+
metrics.push(EpochMetrics::new(epoch, avg_loss));
287+
288+
if epoch <= 10 || epoch % 10 == 0 {
289+
println!("Epoch {epoch}: avg_loss = {avg_loss}");
290+
}
291+
}
292+
Ok(metrics)
293+
}
148294
}
149295

150296
impl Parameterized for SheafNN {
@@ -177,3 +323,234 @@ impl Parameterized for SheafNN {
177323
out
178324
}
179325
}
326+
327+
#[cfg(test)]
mod integration_tests {
    use super::*;
    use crate::{
        math::{
            cell::Cell,
            sheaf::{CellularSheaf, Section},
            tensors::Matrix,
        },
        nn::{
            activate::Activations,
            diffuse::DiffusionLayer,
            loss::LossKind,
            optim::{OptimKind, OptimizerParams},
        },
    };
    use candle_core::{DType, Device};

    /// Builds the smallest non-trivial complex: a triangle with 3 vertices
    /// (0-cells), 3 edges (1-cells), and one face (2-cell), each carrying a
    /// 1-dimensional stalk, with learned restrictions initialised at 0.1.
    fn create_triangle_sheaf() -> Result<CellularSheaf, KohoError> {
        let mut sheaf = CellularSheaf::init(DType::F32, Device::Cpu, true);

        // Create 3 vertices (0-cells) with initial data
        let v0_data = Section::new(&[1.0f32], 1, Device::Cpu, DType::F32)?;
        let (_, v0_idx) = sheaf.attach(Cell::new(0), v0_data, None, None)?;

        let v1_data = Section::new(&[0.0f32], 1, Device::Cpu, DType::F32)?;
        let (_, v1_idx) = sheaf.attach(Cell::new(0), v1_data, None, None)?;

        let v2_data = Section::new(&[0.0f32], 1, Device::Cpu, DType::F32)?;
        let (_, v2_idx) = sheaf.attach(Cell::new(0), v2_data, None, None)?;

        // Create 3 edges (1-cells), each bounded by a pair of vertices
        let e0_data = Section::new(&[0.5f32], 1, Device::Cpu, DType::F32)?;
        let (_, e0_idx) = sheaf.attach(Cell::new(1), e0_data, None, Some(&[v0_idx, v1_idx]))?;

        let e1_data = Section::new(&[0.5f32], 1, Device::Cpu, DType::F32)?;
        let (_, e1_idx) = sheaf.attach(Cell::new(1), e1_data, None, Some(&[v1_idx, v2_idx]))?;

        let e2_data = Section::new(&[0.5f32], 1, Device::Cpu, DType::F32)?;
        let (_, e2_idx) = sheaf.attach(Cell::new(1), e2_data, None, Some(&[v2_idx, v0_idx]))?;

        // Create 1 triangle face (2-cell) bounded by the three edges
        let f0_data = Section::new(&[0.0f32], 1, Device::Cpu, DType::F32)?;
        let (_, _f0_idx) =
            sheaf.attach(Cell::new(2), f0_data, None, Some(&[e0_idx, e1_idx, e2_idx]))?;

        sheaf.generate_initial_restrictions(0.1)?;
        println!("uppers: {:?}", sheaf.cells.cells[0][0].upper);
        Ok(sheaf)
    }

    /// A single linear diffusion layer over the vertices should reduce MSE
    /// loss toward a fixed 3-vertex target over 100 epochs.
    #[test]
    fn test_triangle_diffusion_learning() -> Result<(), KohoError> {
        let sheaf = create_triangle_sheaf()?;
        let input = sheaf.get_k_cochain(0)?;

        let target_data = vec![0.8f32, 0.6f32, 0.4f32];
        let target = Matrix::from_slice(&target_data, 1, 3, Device::Cpu, DType::F32)
            .map_err(KohoError::Candle)?;

        let training_data = vec![(input, target)];
        let mut network = SheafNN::init(0, false, LossKind::MSE, sheaf);

        let diffusion_layer = DiffusionLayer::new(0, Activations::Linear, &network.sheaf)?;
        network.sequential(vec![diffusion_layer]);

        let metrics = network.train_debug(
            &training_data,
            100,
            false,
            OptimKind::Adam,
            0.01,
            OptimizerParams::Else,
        )?;
        assert!(
            metrics.final_loss < metrics.epochs[0].loss,
            "Training should reduce loss over time"
        );

        let initial_input = network.sheaf.get_k_cochain(0)?;
        let output = network.forward(initial_input, false)?;

        // The output should be different from input (diffusion occurred)
        let input_vals = network
            .sheaf
            .get_k_cochain(0)?
            .inner()
            .to_vec2::<f32>()
            .map_err(KohoError::Candle)?;
        let output_vals = output.inner().to_vec2::<f32>().map_err(KohoError::Candle)?;

        println!("Input: {input_vals:?}");
        println!("Output: {output_vals:?}");
        println!("Final loss: {}", metrics.final_loss);

        // The 0-cochain is a 1x3 matrix: one row, one column per vertex.
        assert_eq!(output_vals.len(), 1, "Output should have 1 row");
        assert_eq!(
            output_vals[0].len(),
            3,
            "Output should have one value per vertex (3 vertices)"
        );

        Ok(())
    }

    /// Diffusion over the edges (k = 1) with both up- and down-Laplacian
    /// terms included should train to a reasonable loss.
    #[test]
    fn test_edge_diffusion_learning() -> Result<(), KohoError> {
        let sheaf = create_triangle_sheaf()?;

        // Test diffusion on edges (1-cells)
        let input = sheaf.get_k_cochain(1)?;
        println!("got edges");

        let target_data = vec![0.5f32, 0.3f32, 0.7f32];
        let target = Matrix::from_slice(&target_data, 1, 3, Device::Cpu, DType::F32)
            .map_err(KohoError::Candle)?;

        let training_data = vec![(input, target)];

        let mut network = SheafNN::init(1, true, LossKind::MSE, sheaf); // down_included = true
        let diffusion_layer = DiffusionLayer::new(1, Activations::Tanh, &network.sheaf)?;
        network.sequential(vec![diffusion_layer]);

        let metrics = network.train_debug(
            &training_data,
            200,
            true,
            OptimKind::Adam,
            0.2,
            OptimizerParams::Else,
        )?;
        println!("Edge diffusion final loss: {}", metrics.final_loss);

        assert!(metrics.final_loss < 1.0, "Loss should be reasonable");

        Ok(())
    }

    /// Trains one network with learned restriction maps and one with fixed
    /// restrictions on the same data, and prints both final losses for
    /// comparison (no relative assertion is made).
    #[test]
    fn test_learned_vs_fixed_restrictions() -> Result<(), KohoError> {
        let sheaf_learned = create_triangle_sheaf()?;
        assert!(sheaf_learned.learned, "Sheaf should have learned=true");

        let mut sheaf_fixed = create_triangle_sheaf()?;
        sheaf_fixed.learned = false;
        let input = sheaf_learned.get_k_cochain(0)?;

        let target_data = vec![0.8f32, 0.6f32, 0.4f32];
        let target = Matrix::from_slice(&target_data, 1, 3, Device::Cpu, DType::F32)
            .map_err(KohoError::Candle)?;
        let training_data = vec![(input.clone(), target.clone())];

        // train learned network
        let mut network_learned = SheafNN::init(0, false, LossKind::MSE, sheaf_learned);
        let layer_learned = DiffusionLayer::new(0, Activations::Linear, &network_learned.sheaf)?;
        network_learned.sequential(vec![layer_learned]);

        let metrics_learned = network_learned.train(
            &training_data,
            50,
            false,
            OptimKind::Adam,
            0.01,
            OptimizerParams::Else,
        )?;

        // train fixed network
        // NOTE(review): this side uses `train_debug` while the learned side
        // uses `train` — presumably equivalent apart from logging; confirm.
        let mut network_fixed = SheafNN::init(0, false, LossKind::MSE, sheaf_fixed);
        let layer_fixed = DiffusionLayer::new(0, Activations::Linear, &network_fixed.sheaf)?;
        network_fixed.sequential(vec![layer_fixed]);

        let metrics_fixed = network_fixed.train_debug(
            &training_data,
            50,
            false,
            OptimKind::Adam,
            0.01,
            OptimizerParams::Else,
        )?;

        println!(
            "Learned restrictions final loss: {}",
            metrics_learned.final_loss
        );
        println!(
            "Fixed restrictions final loss: {}",
            metrics_fixed.final_loss
        );

        Ok(())
    }

    /// Stacks three diffusion layers with different activations and checks
    /// that the composed network trains and produces a 1-row output.
    #[test]
    fn test_multiple_diffusion_layers() -> Result<(), KohoError> {
        let sheaf = create_triangle_sheaf()?;
        let input = sheaf.get_k_cochain(0)?;

        let target_data = vec![0.9f32, 0.8f32, 0.7f32];
        let target = Matrix::from_slice(&target_data, 1, 3, Device::Cpu, DType::F32)
            .map_err(KohoError::Candle)?;
        let training_data = vec![(input, target)];

        let mut network = SheafNN::init(0, false, LossKind::MSE, sheaf);

        let layer1 = DiffusionLayer::new(0, Activations::Softmax, &network.sheaf)?;
        let layer2 = DiffusionLayer::new(0, Activations::Tanh, &network.sheaf)?;
        let layer3 = DiffusionLayer::new(0, Activations::Sigmoid, &network.sheaf)?;

        network.sequential(vec![layer1, layer2, layer3]);

        let metrics = network.train_debug(
            &training_data,
            175,
            false,
            OptimKind::Adam,
            0.15,
            OptimizerParams::Else,
        )?;

        println!("Multi-layer network final loss: {}", metrics.final_loss);

        let test_input = network.sheaf.get_k_cochain(0)?;
        let output = network.forward(test_input, false)?;

        assert_eq!(output.rows(), 1, "Output should have 1 row");

        Ok(())
    }
}

0 commit comments

Comments
 (0)