Skip to content

Commit 663fece

Browse files
committed
update to support online training
1 parent 8b67948 commit 663fece

File tree

5 files changed

+391
-1
lines changed

5 files changed

+391
-1
lines changed

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,10 @@ add_executable(test_training_convergence test_training_convergence.c testy/test_
126126
target_link_libraries(test_training_convergence libann ${BLAS_LIB} ${MATH_LIB})
127127
add_test(NAME test_training_convergence COMMAND test_training_convergence)
128128

129+
add_executable(test_online_training test_online_training.c testy/test_main.c)
130+
target_link_libraries(test_online_training libann ${BLAS_LIB} ${MATH_LIB})
131+
add_test(NAME test_online_training COMMAND test_online_training)
132+
129133
add_executable(test_json test_json.c testy/test_main.c)
130134
target_link_libraries(test_json libann ${BLAS_LIB} ${MATH_LIB})
131135
add_test(NAME test_json COMMAND test_json)

README.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,9 @@ ann_save_network | save a trained network (text)
115115
ann_load_network_binary | load a previously saved network (binary)
116116
ann_save_network_binary | save a trained network (binary)
117117
ann_train_network | train a network
118+
ann_train_begin | begin an online/incremental training session
119+
ann_train_step | train one mini-batch step (online training)
120+
ann_train_end | end an online/incremental training session
118121
ann_predict | predict an output using a previously trained network
119122
ann_set_convergence | set the convergence threshold (optional)
120123
ann_evaluate_accuracy | evaluate accuracy of trained network using test data
@@ -266,6 +269,37 @@ epoch,loss,learning_rate
266269
267270
Plot with gnuplot, Python matplotlib, or Excel to diagnose training issues.
268271
272+
## Online / Incremental Training
273+
274+
For scenarios where data arrives incrementally (streaming, fine-tuning a loaded model, or user feedback), use the step-based training API:
275+
276+
```c
277+
PNetwork net = ann_make_network(OPT_ADAM, LOSS_MSE);
278+
ann_add_layer(net, 784, LAYER_INPUT, ACTIVATION_NULL);
279+
ann_add_layer(net, 128, LAYER_HIDDEN, ACTIVATION_SIGMOID);
280+
ann_add_layer(net, 10, LAYER_OUTPUT, ACTIVATION_SOFTMAX);
281+
282+
ann_train_begin(net);
283+
284+
// Feed mini-batches one at a time
285+
for (int i = 0; i < num_batches; i++)
286+
{
287+
real loss = ann_train_step(net, batch_inputs[i], batch_targets[i], batch_size);
288+
printf("Step %d loss: %f\n", i, loss);
289+
290+
// Safe to predict mid-training (dropout is auto-disabled)
291+
ann_predict(net, test_input, prediction);
292+
}
293+
294+
ann_train_end(net);
295+
```
296+
297+
Key differences from `ann_train_network()`:
298+
- **Does not reset optimizer state** — Adam momentum/variance are preserved across calls
299+
- **Does not reinitialize weights** — safe for fine-tuning loaded/pre-trained models
300+
- **Single sample training** — pass `batch_size=1` to train on individual examples
301+
- **`ann_predict()` is safe mid-training** — dropout is automatically disabled during inference
302+
269303
# Hyperparameter Tuning
270304

271305
The `ann_hypertune` module provides automated hyperparameter search to find

ann.c

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2281,6 +2281,114 @@ real ann_train_network(PNetwork pnet, PTensor inputs, PTensor outputs, int rows)
22812281
return loss;
22822282
}
22832283

2284+
//-----------------------------------------------
2285+
// Begin an online/incremental training session
2286+
//-----------------------------------------------
2287+
int ann_train_begin(PNetwork pnet)
2288+
{
2289+
if (!pnet)
2290+
return ERR_NULL_PTR;
2291+
2292+
if (pnet->layer_count <= 0 || !pnet->layers)
2293+
return ERR_INVALID;
2294+
2295+
// Enable training mode (for dropout)
2296+
pnet->is_training = 1;
2297+
2298+
// Save base learning rate for schedulers
2299+
if (pnet->base_learning_rate == (real)0.0)
2300+
pnet->base_learning_rate = pnet->learning_rate;
2301+
2302+
// Initialize weights only if not already set (e.g. loaded model)
2303+
init_weights(pnet);
2304+
2305+
// Ensure batch tensors are allocated for the configured batch size
2306+
if (ensure_batch_tensors(pnet, pnet->batchSize) != ERR_OK)
2307+
{
2308+
invoke_error_callback(ERR_ALLOC, "ann_train_begin");
2309+
return ERR_ALLOC;
2310+
}
2311+
2312+
return ERR_OK;
2313+
}
2314+
2315+
//-----------------------------------------------
2316+
// Train one mini-batch step (online training)
2317+
//-----------------------------------------------
2318+
real ann_train_step(PNetwork pnet, const real *inputs, const real *targets, int batch_size)
2319+
{
2320+
if (!pnet || !inputs || !targets)
2321+
return (real)0.0;
2322+
2323+
if (batch_size <= 0)
2324+
return (real)0.0;
2325+
2326+
int input_node_count = pnet->layers[0].node_count;
2327+
int output_node_count = pnet->layers[pnet->layer_count - 1].node_count;
2328+
2329+
unsigned actual_batch_size = (unsigned)batch_size;
2330+
2331+
// Reallocate batch tensors if batch size changed
2332+
if (pnet->current_batch_size != actual_batch_size)
2333+
{
2334+
if (ensure_batch_tensors(pnet, actual_batch_size) != ERR_OK)
2335+
{
2336+
invoke_error_callback(ERR_ALLOC, "ann_train_step");
2337+
return (real)0.0;
2338+
}
2339+
}
2340+
2341+
// Allocate temporary batch target tensor
2342+
PTensor batch_targets = tensor_create(actual_batch_size, output_node_count);
2343+
if (!batch_targets)
2344+
{
2345+
invoke_error_callback(ERR_ALLOC, "ann_train_step");
2346+
return (real)0.0;
2347+
}
2348+
2349+
// Zero gradients
2350+
for (int layer = 0; layer < pnet->layer_count - 1; layer++)
2351+
{
2352+
tensor_fill(pnet->layers[layer].t_gradients, (real)0.0);
2353+
tensor_fill(pnet->layers[layer].t_bias_grad, (real)0.0);
2354+
}
2355+
2356+
// Copy inputs into batch input tensor
2357+
PTensor batch_input = pnet->layers[0].t_batch_values;
2358+
memcpy(batch_input->values, inputs, actual_batch_size * input_node_count * sizeof(real));
2359+
2360+
// Copy targets into batch target tensor
2361+
memcpy(batch_targets->values, targets, actual_batch_size * output_node_count * sizeof(real));
2362+
2363+
// Forward pass
2364+
eval_network_batched(pnet, actual_batch_size);
2365+
2366+
// Backward pass (computes loss and gradients)
2367+
real loss = back_propagate_batched(pnet, actual_batch_size, batch_targets);
2368+
2369+
// Increment training iteration (for Adam bias correction)
2370+
pnet->train_iteration++;
2371+
2372+
// Update weights
2373+
pnet->optimize_func(pnet);
2374+
2375+
tensor_free(batch_targets);
2376+
2377+
return loss;
2378+
}
2379+
2380+
//-----------------------------------------------
2381+
// End an online/incremental training session
2382+
//-----------------------------------------------
2383+
void ann_train_end(PNetwork pnet)
2384+
{
2385+
if (!pnet)
2386+
return;
2387+
2388+
// Disable training mode (for dropout)
2389+
pnet->is_training = 0;
2390+
}
2391+
22842392
//------------------------------
22852393
// evaluate the accuracy
22862394
//------------------------------
@@ -2842,9 +2950,16 @@ int ann_predict(const PNetwork pnet, const real *inputs, real *outputs)
28422950
pnet->layers[0].t_values->values[node] = *inputs++;
28432951
}
28442952

2953+
// Temporarily disable training mode for inference (prevents dropout)
2954+
int was_training = pnet->is_training;
2955+
pnet->is_training = 0;
2956+
28452957
// evaluate network
28462958
eval_network(pnet);
28472959

2960+
// Restore training mode
2961+
pnet->is_training = was_training;
2962+
28482963
// get the outputs
28492964
node_count = pnet->layers[pnet->layer_count - 1].node_count;
28502965
for (int node = 0; node < node_count; node++)

ann.h

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -497,11 +497,67 @@ ANN_API PNetwork ann_load_network_binary(const char *filename);
497497
*/
498498
ANN_API real ann_train_network(PNetwork pnet, PTensor inputs, PTensor outputs, int rows);
499499

500+
/**
501+
* Begin an online/incremental training session.
502+
*
503+
* Prepares the network for step-by-step training without resetting
504+
* optimizer state (e.g. Adam momentum). Weights are initialized only
505+
* if not already set (safe for loaded/pre-trained models).
506+
*
507+
* Call ann_train_step() to train on individual mini-batches, then
508+
* ann_train_end() when finished.
509+
*
510+
* @param pnet Network to train (must have layers defined)
511+
* @return ERR_OK on success
512+
* @return ERR_NULL_PTR if pnet is NULL
513+
* @return ERR_INVALID if network has no layers
514+
* @return ERR_ALLOC if batch tensor allocation fails
515+
*
516+
* @see ann_train_step() to train on a single mini-batch
517+
* @see ann_train_end() to finish the training session
518+
*/
519+
ANN_API int ann_train_begin(PNetwork pnet);
520+
521+
/**
522+
* Train on a single mini-batch (online/incremental training).
523+
*
524+
* Performs one forward pass, backward pass, and weight update on the
525+
* provided mini-batch. Does not reset optimizer state between calls,
526+
* enabling incremental learning on streaming data.
527+
*
528+
* Must be called between ann_train_begin() and ann_train_end().
529+
*
530+
* @param pnet Network being trained
531+
* @param inputs Input data (batch_size consecutive input vectors, each of
532+
* size = first layer node_count)
533+
* @param targets Target data (batch_size consecutive target vectors, each of
534+
* size = last layer node_count)
535+
* @param batch_size Number of samples in this mini-batch
536+
* @return Loss for this mini-batch, or 0.0 on error
537+
*
538+
* @see ann_train_begin() to start a training session
539+
* @see ann_train_end() to finish the training session
540+
*/
541+
ANN_API real ann_train_step(PNetwork pnet, const real *inputs, const real *targets, int batch_size);
542+
543+
/**
544+
* End an online/incremental training session.
545+
*
546+
* Disables training mode (stops dropout from being applied).
547+
* The network is ready for inference after this call.
548+
*
549+
* @param pnet Network that was being trained
550+
*
551+
* @see ann_train_begin() to start a training session
552+
*/
553+
ANN_API void ann_train_end(PNetwork pnet);
554+
500555
/**
501556
* Run trained network on single input to produce output.
502557
*
503558
* Forward-propagates input through all layers and returns the output
504-
* layer activations. Network must be trained before calling.
559+
* layer activations. Safe to call during online training (between
560+
* ann_train_begin/ann_train_end) — dropout is automatically disabled.
505561
*
506562
* @param pnet Trained network (must not be NULL)
507563
* @param inputs Input feature vector (size = first layer node_count)

0 commit comments

Comments
 (0)