@@ -110,8 +110,7 @@ impl SheafNN {
110110 ) -> Result < TrainingMetrics , KohoError > {
111111 // Create the optimizer
112112 let mut optimizer =
113- create_optimizer ( optimizer_kind, self . parameters_mut ( ) , lr, optimizer_params)
114- . map_err ( KohoError :: Candle ) ?;
113+ create_optimizer ( optimizer_kind, self . parameters_mut ( ) , lr, optimizer_params) ?;
115114
116115 let mut metrics = TrainingMetrics :: new ( epochs) ;
117116
@@ -123,21 +122,16 @@ impl SheafNN {
123122 let output = self . forward ( input. clone ( ) , down_included) ?;
124123
125124 // Compute loss
126- let loss_tensor = self
127- . loss_fn
128- . compute ( output. inner ( ) , target. inner ( ) )
129- . map_err ( KohoError :: Candle ) ?;
125+ let loss_tensor = self . loss_fn . compute ( output. inner ( ) , target. inner ( ) ) ?;
130126
131127 let loss_val = loss_tensor. to_scalar :: < f32 > ( ) . unwrap_or ( f32:: NAN ) ;
132128 total_loss += loss_val;
133129
134130 // Backward pass
135- let grads = loss_tensor. backward ( ) . map_err ( KohoError :: Candle ) ?;
131+ let grads = loss_tensor. backward ( ) ?;
136132
137133 // Optimizer step (in-place update of parameters)
138- optimizer
139- . step ( & grads, self . parameters_mut ( ) )
140- . map_err ( KohoError :: Candle ) ?;
134+ optimizer. step ( & grads, self . parameters_mut ( ) ) ?;
141135 }
142136
143137 let avg_loss = total_loss / ( data. len ( ) as f32 ) ;
@@ -161,8 +155,8 @@ impl SheafNN {
161155 let params = self . parameters ( ) ;
162156 println ! ( "Total parameters: {}" , params. len( ) ) ;
163157 for ( i, param) in params. iter ( ) . enumerate ( ) {
164- let param_data = param. as_tensor ( ) . flatten_all ( ) . map_err ( KohoError :: Candle ) ?;
165- let param_vec = param_data. to_vec1 :: < f32 > ( ) . map_err ( KohoError :: Candle ) ?;
158+ let param_data = param. as_tensor ( ) . flatten_all ( ) ?;
159+ let param_vec = param_data. to_vec1 :: < f32 > ( ) ?;
166160 println ! (
167161 "Parameter {i}: shape={:?}, first_few_values={:?}" ,
168162 param. shape( ) ,
@@ -172,8 +166,7 @@ impl SheafNN {
172166
173167 // Create the optimizer
174168 let mut optimizer =
175- create_optimizer ( optimizer_kind, self . parameters_mut ( ) , lr, optimizer_params)
176- . map_err ( KohoError :: Candle ) ?;
169+ create_optimizer ( optimizer_kind, self . parameters_mut ( ) , lr, optimizer_params) ?;
177170
178171 let mut metrics = TrainingMetrics :: new ( epochs) ;
179172
@@ -184,25 +177,22 @@ impl SheafNN {
184177 println ! ( "\n Epoch {epoch}, Batch {batch_idx}" ) ;
185178
186179 // Print input/target info
187- let input_data = input. inner ( ) . flatten_all ( ) . map_err ( KohoError :: Candle ) ?;
188- let target_data = target. inner ( ) . flatten_all ( ) . map_err ( KohoError :: Candle ) ?;
189- let input_vec = input_data. to_vec1 :: < f32 > ( ) . map_err ( KohoError :: Candle ) ?;
190- let target_vec = target_data. to_vec1 :: < f32 > ( ) . map_err ( KohoError :: Candle ) ?;
180+ let input_data = input. inner ( ) . flatten_all ( ) ?;
181+ let target_data = target. inner ( ) . flatten_all ( ) ?;
182+ let input_vec = input_data. to_vec1 :: < f32 > ( ) ?;
183+ let target_vec = target_data. to_vec1 :: < f32 > ( ) ?;
191184
192185 println ! ( "Input: {input_vec:?}" ) ;
193186 println ! ( "Target: {target_vec:?}" ) ;
194187
195188 // Forward pass
196189 let output = self . forward ( input. clone ( ) , down_included) ?;
197- let output_data = output. inner ( ) . flatten_all ( ) . map_err ( KohoError :: Candle ) ?;
198- let output_vec = output_data. to_vec1 :: < f32 > ( ) . map_err ( KohoError :: Candle ) ?;
190+ let output_data = output. inner ( ) . flatten_all ( ) ?;
191+ let output_vec = output_data. to_vec1 :: < f32 > ( ) ?;
199192 println ! ( "Output: {output_vec:?}" ) ;
200193
201194 // Compute loss
202- let loss_tensor = self
203- . loss_fn
204- . compute ( output. inner ( ) , target. inner ( ) )
205- . map_err ( KohoError :: Candle ) ?;
195+ let loss_tensor = self . loss_fn . compute ( output. inner ( ) , target. inner ( ) ) ?;
206196
207197 let loss_val = loss_tensor. to_scalar :: < f32 > ( ) . unwrap_or ( f32:: NAN ) ;
208198 total_loss += loss_val;
@@ -214,15 +204,15 @@ impl SheafNN {
214204
215205 // Backward pass
216206 println ! ( "Computing gradients..." ) ;
217- let grads = loss_tensor. backward ( ) . map_err ( KohoError :: Candle ) ?;
207+ let grads = loss_tensor. backward ( ) ?;
218208
219209 // Check gradients
220210 let params_mut = self . parameters_mut ( ) ;
221211 println ! ( "Checking gradients for {} parameters:" , params_mut. len( ) ) ;
222212 for ( i, param) in params_mut. iter ( ) . enumerate ( ) {
223213 if let Some ( grad) = grads. get ( param) {
224- let grad_data = grad. flatten_all ( ) . map_err ( KohoError :: Candle ) ?;
225- let grad_vec = grad_data. to_vec1 :: < f32 > ( ) . map_err ( KohoError :: Candle ) ?;
214+ let grad_data = grad. flatten_all ( ) ?;
215+ let grad_vec = grad_data. to_vec1 :: < f32 > ( ) ?;
226216 let grad_norm = grad_vec. iter ( ) . map ( |x| x * x) . sum :: < f32 > ( ) . sqrt ( ) ;
227217 println ! (
228218 " Param {i}: grad_norm={grad_norm}, first_few_grads={:?}" ,
@@ -247,9 +237,7 @@ impl SheafNN {
247237 } )
248238 . collect ( ) ;
249239
250- optimizer
251- . step ( & grads, self . parameters_mut ( ) )
252- . map_err ( KohoError :: Candle ) ?;
240+ optimizer. step ( & grads, self . parameters_mut ( ) ) ?;
253241
254242 let params_after: Vec < _ > = self
255243 . parameters_mut ( )
@@ -382,8 +370,7 @@ mod integration_tests {
382370 let input = sheaf. get_k_cochain ( 0 ) ?;
383371
384372 let target_data = vec ! [ 0.8f32 , 0.6f32 , 0.4f32 ] ;
385- let target = Matrix :: from_slice ( & target_data, 1 , 3 , Device :: Cpu , DType :: F32 )
386- . map_err ( KohoError :: Candle ) ?;
373+ let target = Matrix :: from_slice ( & target_data, 1 , 3 , Device :: Cpu , DType :: F32 ) ?;
387374
388375 let training_data = vec ! [ ( input, target) ] ;
389376 let mut network = SheafNN :: init ( 0 , false , LossKind :: MSE , sheaf) ;
@@ -408,13 +395,8 @@ mod integration_tests {
408395 let output = network. forward ( initial_input, false ) ?;
409396
410397 // The output should be different from input (diffusion occurred)
411- let input_vals = network
412- . sheaf
413- . get_k_cochain ( 0 ) ?
414- . inner ( )
415- . to_vec2 :: < f32 > ( )
416- . map_err ( KohoError :: Candle ) ?;
417- let output_vals = output. inner ( ) . to_vec2 :: < f32 > ( ) . map_err ( KohoError :: Candle ) ?;
398+ let input_vals = network. sheaf . get_k_cochain ( 0 ) ?. inner ( ) . to_vec2 :: < f32 > ( ) ?;
399+ let output_vals = output. inner ( ) . to_vec2 :: < f32 > ( ) ?;
418400
419401 println ! ( "Input: {input_vals:?}" ) ;
420402 println ! ( "Output: {output_vals:?}" ) ;
@@ -439,8 +421,7 @@ mod integration_tests {
439421 println ! ( "got edges" ) ;
440422
441423 let target_data = vec ! [ 0.5f32 , 0.3f32 , 0.7f32 ] ;
442- let target = Matrix :: from_slice ( & target_data, 1 , 3 , Device :: Cpu , DType :: F32 )
443- . map_err ( KohoError :: Candle ) ?;
424+ let target = Matrix :: from_slice ( & target_data, 1 , 3 , Device :: Cpu , DType :: F32 ) ?;
444425
445426 let training_data = vec ! [ ( input, target) ] ;
446427
@@ -473,8 +454,7 @@ mod integration_tests {
473454 let input = sheaf_learned. get_k_cochain ( 0 ) ?;
474455
475456 let target_data = vec ! [ 0.8f32 , 0.6f32 , 0.4f32 ] ;
476- let target = Matrix :: from_slice ( & target_data, 1 , 3 , Device :: Cpu , DType :: F32 )
477- . map_err ( KohoError :: Candle ) ?;
457+ let target = Matrix :: from_slice ( & target_data, 1 , 3 , Device :: Cpu , DType :: F32 ) ?;
478458 let training_data = vec ! [ ( input. clone( ) , target. clone( ) ) ] ;
479459
480460 // train learned network
@@ -523,8 +503,7 @@ mod integration_tests {
523503 let input = sheaf. get_k_cochain ( 0 ) ?;
524504
525505 let target_data = vec ! [ 0.9f32 , 0.8f32 , 0.7f32 ] ;
526- let target = Matrix :: from_slice ( & target_data, 1 , 3 , Device :: Cpu , DType :: F32 )
527- . map_err ( KohoError :: Candle ) ?;
506+ let target = Matrix :: from_slice ( & target_data, 1 , 3 , Device :: Cpu , DType :: F32 ) ?;
528507 let training_data = vec ! [ ( input, target) ] ;
529508
530509 let mut network = SheafNN :: init ( 0 , false , LossKind :: MSE , sheaf) ;
0 commit comments