@@ -145,6 +145,152 @@ impl SheafNN {
145145 }
146146 Ok ( metrics)
147147 }
148+
149+ pub fn train_debug (
150+ & mut self ,
151+ data : & [ ( Matrix , Matrix ) ] ,
152+ epochs : usize ,
153+ down_included : bool ,
154+ optimizer_kind : OptimKind ,
155+ lr : f64 ,
156+ optimizer_params : OptimizerParams ,
157+ ) -> Result < TrainingMetrics , KohoError > {
158+ println ! ( "=== Training Debug Info ===" ) ;
159+
160+ // Check if we have any parameters to optimize
161+ let params = self . parameters ( ) ;
162+ println ! ( "Total parameters: {}" , params. len( ) ) ;
163+ for ( i, param) in params. iter ( ) . enumerate ( ) {
164+ let param_data = param. as_tensor ( ) . flatten_all ( ) . map_err ( KohoError :: Candle ) ?;
165+ let param_vec = param_data. to_vec1 :: < f32 > ( ) . map_err ( KohoError :: Candle ) ?;
166+ println ! (
167+ "Parameter {i}: shape={:?}, first_few_values={:?}" ,
168+ param. shape( ) ,
169+ & param_vec[ ..param_vec. len( ) . min( 5 ) ]
170+ ) ;
171+ }
172+
173+ // Create the optimizer
174+ let mut optimizer =
175+ create_optimizer ( optimizer_kind, self . parameters_mut ( ) , lr, optimizer_params)
176+ . map_err ( KohoError :: Candle ) ?;
177+
178+ let mut metrics = TrainingMetrics :: new ( epochs) ;
179+
180+ for epoch in 1 ..=epochs {
181+ let mut total_loss = 0.0_f32 ;
182+
183+ for ( batch_idx, ( input, target) ) in data. iter ( ) . enumerate ( ) {
184+ println ! ( "\n Epoch {epoch}, Batch {batch_idx}" ) ;
185+
186+ // Print input/target info
187+ let input_data = input. inner ( ) . flatten_all ( ) . map_err ( KohoError :: Candle ) ?;
188+ let target_data = target. inner ( ) . flatten_all ( ) . map_err ( KohoError :: Candle ) ?;
189+ let input_vec = input_data. to_vec1 :: < f32 > ( ) . map_err ( KohoError :: Candle ) ?;
190+ let target_vec = target_data. to_vec1 :: < f32 > ( ) . map_err ( KohoError :: Candle ) ?;
191+
192+ println ! ( "Input: {input_vec:?}" ) ;
193+ println ! ( "Target: {target_vec:?}" ) ;
194+
195+ // Forward pass
196+ let output = self . forward ( input. clone ( ) , down_included) ?;
197+ let output_data = output. inner ( ) . flatten_all ( ) . map_err ( KohoError :: Candle ) ?;
198+ let output_vec = output_data. to_vec1 :: < f32 > ( ) . map_err ( KohoError :: Candle ) ?;
199+ println ! ( "Output: {output_vec:?}" ) ;
200+
201+ // Compute loss
202+ let loss_tensor = self
203+ . loss_fn
204+ . compute ( output. inner ( ) , target. inner ( ) )
205+ . map_err ( KohoError :: Candle ) ?;
206+
207+ let loss_val = loss_tensor. to_scalar :: < f32 > ( ) . unwrap_or ( f32:: NAN ) ;
208+ total_loss += loss_val;
209+ println ! ( "Loss: {loss_val}" ) ;
210+
211+ // Check if loss tensor requires grad
212+ println ! ( "Loss tensor shape: {:?}" , loss_tensor. shape( ) ) ;
213+ println ! ( "Loss tensor dtype: {:?}" , loss_tensor. dtype( ) ) ;
214+
215+ // Backward pass
216+ println ! ( "Computing gradients..." ) ;
217+ let grads = loss_tensor. backward ( ) . map_err ( KohoError :: Candle ) ?;
218+
219+ // Check gradients
220+ let params_mut = self . parameters_mut ( ) ;
221+ println ! ( "Checking gradients for {} parameters:" , params_mut. len( ) ) ;
222+ for ( i, param) in params_mut. iter ( ) . enumerate ( ) {
223+ if let Some ( grad) = grads. get ( param) {
224+ let grad_data = grad. flatten_all ( ) . map_err ( KohoError :: Candle ) ?;
225+ let grad_vec = grad_data. to_vec1 :: < f32 > ( ) . map_err ( KohoError :: Candle ) ?;
226+ let grad_norm = grad_vec. iter ( ) . map ( |x| x * x) . sum :: < f32 > ( ) . sqrt ( ) ;
227+ println ! (
228+ " Param {i}: grad_norm={grad_norm}, first_few_grads={:?}" ,
229+ & grad_vec[ ..grad_vec. len( ) . min( 3 ) ]
230+ ) ;
231+ } else {
232+ println ! ( " Param {i}: NO GRADIENT FOUND" ) ;
233+ }
234+ }
235+
236+ // Optimizer step (in-place update of parameters)
237+ println ! ( "Applying optimizer step..." ) ;
238+ let params_before: Vec < _ > = self
239+ . parameters_mut ( )
240+ . iter ( )
241+ . map ( |p| {
242+ p. as_tensor ( )
243+ . flatten_all ( )
244+ . unwrap ( )
245+ . to_vec1 :: < f32 > ( )
246+ . unwrap ( )
247+ } )
248+ . collect ( ) ;
249+
250+ optimizer
251+ . step ( & grads, self . parameters_mut ( ) )
252+ . map_err ( KohoError :: Candle ) ?;
253+
254+ let params_after: Vec < _ > = self
255+ . parameters_mut ( )
256+ . iter ( )
257+ . map ( |p| {
258+ p. as_tensor ( )
259+ . flatten_all ( )
260+ . unwrap ( )
261+ . to_vec1 :: < f32 > ( )
262+ . unwrap ( )
263+ } )
264+ . collect ( ) ;
265+
266+ // Check if parameters actually changed
267+ for ( i, ( before, after) ) in
268+ params_before. iter ( ) . zip ( params_after. iter ( ) ) . enumerate ( )
269+ {
270+ let diff_norm: f32 = before
271+ . iter ( )
272+ . zip ( after. iter ( ) )
273+ . map ( |( b, a) | ( b - a) . powi ( 2 ) )
274+ . sum :: < f32 > ( )
275+ . sqrt ( ) ;
276+ println ! ( " Param {i} change norm: {diff_norm}" ) ;
277+ }
278+
279+ if epoch <= 3 {
280+ // Only print detailed info for first few epochs
281+ println ! ( "--- End batch {batch_idx} ---" ) ;
282+ }
283+ }
284+
285+ let avg_loss = total_loss / ( data. len ( ) as f32 ) ;
286+ metrics. push ( EpochMetrics :: new ( epoch, avg_loss) ) ;
287+
288+ if epoch <= 10 || epoch % 10 == 0 {
289+ println ! ( "Epoch {epoch}: avg_loss = {avg_loss}" ) ;
290+ }
291+ }
292+ Ok ( metrics)
293+ }
148294}
149295
150296impl Parameterized for SheafNN {
@@ -177,3 +323,234 @@ impl Parameterized for SheafNN {
177323 out
178324 }
179325}
326+
#[cfg(test)]
mod integration_tests {
    use super::*;
    use crate::{
        math::{
            cell::Cell,
            sheaf::{CellularSheaf, Section},
            tensors::Matrix,
        },
        nn::{
            activate::Activations,
            diffuse::DiffusionLayer,
            loss::LossKind,
            optim::{OptimKind, OptimizerParams},
        },
    };
    use candle_core::{DType, Device};

    /// Creates a triangle complex: 3 vertices (0-cells), 3 edges (1-cells),
    /// and 1 face (2-cell), with 1-dimensional stalks and learned
    /// restriction maps enabled.
    fn create_triangle_sheaf() -> Result<CellularSheaf, KohoError> {
        let mut sheaf = CellularSheaf::init(DType::F32, Device::Cpu, true);

        // Create 3 vertices (0-cells) with initial data
        let v0_data = Section::new(&[1.0f32], 1, Device::Cpu, DType::F32)?;
        let (_, v0_idx) = sheaf.attach(Cell::new(0), v0_data, None, None)?;

        let v1_data = Section::new(&[0.0f32], 1, Device::Cpu, DType::F32)?;
        let (_, v1_idx) = sheaf.attach(Cell::new(0), v1_data, None, None)?;

        let v2_data = Section::new(&[0.0f32], 1, Device::Cpu, DType::F32)?;
        let (_, v2_idx) = sheaf.attach(Cell::new(0), v2_data, None, None)?;

        // Create 3 edges (1-cells), each attached to its two endpoints
        let e0_data = Section::new(&[0.5f32], 1, Device::Cpu, DType::F32)?;
        let (_, e0_idx) = sheaf.attach(Cell::new(1), e0_data, None, Some(&[v0_idx, v1_idx]))?;

        let e1_data = Section::new(&[0.5f32], 1, Device::Cpu, DType::F32)?;
        let (_, e1_idx) = sheaf.attach(Cell::new(1), e1_data, None, Some(&[v1_idx, v2_idx]))?;

        let e2_data = Section::new(&[0.5f32], 1, Device::Cpu, DType::F32)?;
        let (_, e2_idx) = sheaf.attach(Cell::new(1), e2_data, None, Some(&[v2_idx, v0_idx]))?;

        // Create 1 triangle face (2-cell) bounded by the three edges
        let f0_data = Section::new(&[0.0f32], 1, Device::Cpu, DType::F32)?;
        let (_, _f0_idx) =
            sheaf.attach(Cell::new(2), f0_data, None, Some(&[e0_idx, e1_idx, e2_idx]))?;

        sheaf.generate_initial_restrictions(0.1)?;
        println!("uppers: {:?}", sheaf.cells.cells[0][0].upper);
        Ok(sheaf)
    }

    /// A single linear diffusion layer on the 0-cells should be able to
    /// reduce MSE against a fixed vertex target.
    #[test]
    fn test_triangle_diffusion_learning() -> Result<(), KohoError> {
        let sheaf = create_triangle_sheaf()?;
        let input = sheaf.get_k_cochain(0)?;

        let target_data = vec![0.8f32, 0.6f32, 0.4f32];
        let target = Matrix::from_slice(&target_data, 1, 3, Device::Cpu, DType::F32)
            .map_err(KohoError::Candle)?;

        let training_data = vec![(input, target)];
        let mut network = SheafNN::init(0, false, LossKind::MSE, sheaf);

        let diffusion_layer = DiffusionLayer::new(0, Activations::Linear, &network.sheaf)?;
        network.sequential(vec![diffusion_layer]);

        let metrics = network.train_debug(
            &training_data,
            100,
            false,
            OptimKind::Adam,
            0.01,
            OptimizerParams::Else,
        )?;
        assert!(
            metrics.final_loss < metrics.epochs[0].loss,
            "Training should reduce loss over time"
        );

        let initial_input = network.sheaf.get_k_cochain(0)?;
        let output = network.forward(initial_input, false)?;

        // The output should be different from input (diffusion occurred)
        let input_vals = network
            .sheaf
            .get_k_cochain(0)?
            .inner()
            .to_vec2::<f32>()
            .map_err(KohoError::Candle)?;
        let output_vals = output.inner().to_vec2::<f32>().map_err(KohoError::Candle)?;

        println!("Input: {input_vals:?}");
        println!("Output: {output_vals:?}");
        println!("Final loss: {}", metrics.final_loss);

        // Output is a 1x3 matrix: one row holding one value per vertex.
        assert_eq!(output_vals.len(), 1, "Output should have 1 feature");
        assert_eq!(
            output_vals[0].len(),
            3,
            "Output row should have one value per vertex (3)"
        );

        Ok(())
    }

    /// Diffusion on the 1-cells (edges), with the downward Laplacian term
    /// included, should converge to a reasonable loss.
    #[test]
    fn test_edge_diffusion_learning() -> Result<(), KohoError> {
        let sheaf = create_triangle_sheaf()?;

        // Test diffusion on edges (1-cells)
        let input = sheaf.get_k_cochain(1)?;
        println!("got edges");

        let target_data = vec![0.5f32, 0.3f32, 0.7f32];
        let target = Matrix::from_slice(&target_data, 1, 3, Device::Cpu, DType::F32)
            .map_err(KohoError::Candle)?;

        let training_data = vec![(input, target)];

        let mut network = SheafNN::init(1, true, LossKind::MSE, sheaf); // down_included = true
        let diffusion_layer = DiffusionLayer::new(1, Activations::Tanh, &network.sheaf)?;
        network.sequential(vec![diffusion_layer]);

        let metrics = network.train_debug(
            &training_data,
            200,
            true,
            OptimKind::Adam,
            0.2,
            OptimizerParams::Else,
        )?;
        println!("Edge diffusion final loss: {}", metrics.final_loss);

        assert!(metrics.final_loss < 1.0, "Loss should be reasonable");

        Ok(())
    }

    /// Trains two otherwise-identical networks — one with learned
    /// restriction maps, one with fixed — and prints their final losses
    /// for comparison.
    #[test]
    fn test_learned_vs_fixed_restrictions() -> Result<(), KohoError> {
        let sheaf_learned = create_triangle_sheaf()?;
        assert!(sheaf_learned.learned, "Sheaf should have learned=true");

        let mut sheaf_fixed = create_triangle_sheaf()?;
        sheaf_fixed.learned = false;
        let input = sheaf_learned.get_k_cochain(0)?;

        let target_data = vec![0.8f32, 0.6f32, 0.4f32];
        let target = Matrix::from_slice(&target_data, 1, 3, Device::Cpu, DType::F32)
            .map_err(KohoError::Candle)?;
        let training_data = vec![(input.clone(), target.clone())];

        // train learned network
        let mut network_learned = SheafNN::init(0, false, LossKind::MSE, sheaf_learned);
        let layer_learned = DiffusionLayer::new(0, Activations::Linear, &network_learned.sheaf)?;
        network_learned.sequential(vec![layer_learned]);

        let metrics_learned = network_learned.train(
            &training_data,
            50,
            false,
            OptimKind::Adam,
            0.01,
            OptimizerParams::Else,
        )?;

        // train fixed network with the same `train` entry point so the two
        // runs are directly comparable (was: train_debug for this one only)
        let mut network_fixed = SheafNN::init(0, false, LossKind::MSE, sheaf_fixed);
        let layer_fixed = DiffusionLayer::new(0, Activations::Linear, &network_fixed.sheaf)?;
        network_fixed.sequential(vec![layer_fixed]);

        let metrics_fixed = network_fixed.train(
            &training_data,
            50,
            false,
            OptimKind::Adam,
            0.01,
            OptimizerParams::Else,
        )?;

        println!(
            "Learned restrictions final loss: {}",
            metrics_learned.final_loss
        );
        println!(
            "Fixed restrictions final loss: {}",
            metrics_fixed.final_loss
        );

        Ok(())
    }

    /// A stack of three diffusion layers (softmax -> tanh -> sigmoid) on the
    /// 0-cells should train end-to-end and produce a single-row output.
    #[test]
    fn test_multiple_diffusion_layers() -> Result<(), KohoError> {
        let sheaf = create_triangle_sheaf()?;
        let input = sheaf.get_k_cochain(0)?;

        let target_data = vec![0.9f32, 0.8f32, 0.7f32];
        let target = Matrix::from_slice(&target_data, 1, 3, Device::Cpu, DType::F32)
            .map_err(KohoError::Candle)?;
        let training_data = vec![(input, target)];

        let mut network = SheafNN::init(0, false, LossKind::MSE, sheaf);

        let layer1 = DiffusionLayer::new(0, Activations::Softmax, &network.sheaf)?;
        let layer2 = DiffusionLayer::new(0, Activations::Tanh, &network.sheaf)?;
        let layer3 = DiffusionLayer::new(0, Activations::Sigmoid, &network.sheaf)?;

        network.sequential(vec![layer1, layer2, layer3]);

        let metrics = network.train_debug(
            &training_data,
            175,
            false,
            OptimKind::Adam,
            0.15,
            OptimizerParams::Else,
        )?;

        println!("Multi-layer network final loss: {}", metrics.final_loss);

        let test_input = network.sheaf.get_k_cochain(0)?;
        let output = network.forward(test_input, false)?;

        assert_eq!(output.rows(), 1, "Output should be a single row");

        Ok(())
    }
}
0 commit comments