@@ -2281,6 +2281,114 @@ real ann_train_network(PNetwork pnet, PTensor inputs, PTensor outputs, int rows)
22812281 return loss ;
22822282}
22832283
2284+ //-----------------------------------------------
2285+ // Begin an online/incremental training session
2286+ //-----------------------------------------------
2287+ int ann_train_begin (PNetwork pnet )
2288+ {
2289+ if (!pnet )
2290+ return ERR_NULL_PTR ;
2291+
2292+ if (pnet -> layer_count <= 0 || !pnet -> layers )
2293+ return ERR_INVALID ;
2294+
2295+ // Enable training mode (for dropout)
2296+ pnet -> is_training = 1 ;
2297+
2298+ // Save base learning rate for schedulers
2299+ if (pnet -> base_learning_rate == (real )0.0 )
2300+ pnet -> base_learning_rate = pnet -> learning_rate ;
2301+
2302+ // Initialize weights only if not already set (e.g. loaded model)
2303+ init_weights (pnet );
2304+
2305+ // Ensure batch tensors are allocated for the configured batch size
2306+ if (ensure_batch_tensors (pnet , pnet -> batchSize ) != ERR_OK )
2307+ {
2308+ invoke_error_callback (ERR_ALLOC , "ann_train_begin" );
2309+ return ERR_ALLOC ;
2310+ }
2311+
2312+ return ERR_OK ;
2313+ }
2314+
2315+ //-----------------------------------------------
2316+ // Train one mini-batch step (online training)
2317+ //-----------------------------------------------
2318+ real ann_train_step (PNetwork pnet , const real * inputs , const real * targets , int batch_size )
2319+ {
2320+ if (!pnet || !inputs || !targets )
2321+ return (real )0.0 ;
2322+
2323+ if (batch_size <= 0 )
2324+ return (real )0.0 ;
2325+
2326+ int input_node_count = pnet -> layers [0 ].node_count ;
2327+ int output_node_count = pnet -> layers [pnet -> layer_count - 1 ].node_count ;
2328+
2329+ unsigned actual_batch_size = (unsigned )batch_size ;
2330+
2331+ // Reallocate batch tensors if batch size changed
2332+ if (pnet -> current_batch_size != actual_batch_size )
2333+ {
2334+ if (ensure_batch_tensors (pnet , actual_batch_size ) != ERR_OK )
2335+ {
2336+ invoke_error_callback (ERR_ALLOC , "ann_train_step" );
2337+ return (real )0.0 ;
2338+ }
2339+ }
2340+
2341+ // Allocate temporary batch target tensor
2342+ PTensor batch_targets = tensor_create (actual_batch_size , output_node_count );
2343+ if (!batch_targets )
2344+ {
2345+ invoke_error_callback (ERR_ALLOC , "ann_train_step" );
2346+ return (real )0.0 ;
2347+ }
2348+
2349+ // Zero gradients
2350+ for (int layer = 0 ; layer < pnet -> layer_count - 1 ; layer ++ )
2351+ {
2352+ tensor_fill (pnet -> layers [layer ].t_gradients , (real )0.0 );
2353+ tensor_fill (pnet -> layers [layer ].t_bias_grad , (real )0.0 );
2354+ }
2355+
2356+ // Copy inputs into batch input tensor
2357+ PTensor batch_input = pnet -> layers [0 ].t_batch_values ;
2358+ memcpy (batch_input -> values , inputs , actual_batch_size * input_node_count * sizeof (real ));
2359+
2360+ // Copy targets into batch target tensor
2361+ memcpy (batch_targets -> values , targets , actual_batch_size * output_node_count * sizeof (real ));
2362+
2363+ // Forward pass
2364+ eval_network_batched (pnet , actual_batch_size );
2365+
2366+ // Backward pass (computes loss and gradients)
2367+ real loss = back_propagate_batched (pnet , actual_batch_size , batch_targets );
2368+
2369+ // Increment training iteration (for Adam bias correction)
2370+ pnet -> train_iteration ++ ;
2371+
2372+ // Update weights
2373+ pnet -> optimize_func (pnet );
2374+
2375+ tensor_free (batch_targets );
2376+
2377+ return loss ;
2378+ }
2379+
2380+ //-----------------------------------------------
2381+ // End an online/incremental training session
2382+ //-----------------------------------------------
2383+ void ann_train_end (PNetwork pnet )
2384+ {
2385+ if (!pnet )
2386+ return ;
2387+
2388+ // Disable training mode (for dropout)
2389+ pnet -> is_training = 0 ;
2390+ }
2391+
22842392//------------------------------
22852393// evaluate the accuracy
22862394//------------------------------
@@ -2842,9 +2950,16 @@ int ann_predict(const PNetwork pnet, const real *inputs, real *outputs)
28422950 pnet -> layers [0 ].t_values -> values [node ] = * inputs ++ ;
28432951 }
28442952
2953+ // Temporarily disable training mode for inference (prevents dropout)
2954+ int was_training = pnet -> is_training ;
2955+ pnet -> is_training = 0 ;
2956+
28452957 // evaluate network
28462958 eval_network (pnet );
28472959
2960+ // Restore training mode
2961+ pnet -> is_training = was_training ;
2962+
28482963 // get the outputs
28492964 node_count = pnet -> layers [pnet -> layer_count - 1 ].node_count ;
28502965 for (int node = 0 ; node < node_count ; node ++ )
0 commit comments