Skip to content

Commit 9d99314

Browse files
committed
Ref : New implementation with updated APIs
1 parent 5cfee53 commit 9d99314

File tree

1 file changed

+17
-50
lines changed

1 file changed

+17
-50
lines changed

examples/nn_training_example.c

Lines changed: 17 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,38 @@
11
#include <stdio.h>
22
#include <stdlib.h>
3-
#include <time.h>
43
#include "../include/Core/training.h"
4+
#include "../include/Core/dataset.h"
55

66
int main()
77
{
8-
srand(time(NULL));
98
NeuralNetwork *network = create_neural_network(2);
10-
11-
build_network(network, OPTIMIZER_ADAM, 0.01f, LOSS_MSE, 0.0f, 0.0f);
9+
build_network(network, OPTIMIZER_ADAM, 0.1f, LOSS_MSE, 0.0f, 0.0f);
1210
model_add(network, LAYER_DENSE, ACTIVATION_RELU, 2, 4, 0.0f, 0, 0);
1311
model_add(network, LAYER_DENSE, ACTIVATION_TANH, 4, 4, 0.0f, 0, 0);
1412
model_add(network, LAYER_DENSE, ACTIVATION_SIGMOID, 4, 1, 0.0f, 0, 0);
1513

16-
int num_samples = 4;
17-
float **X_train = (float **)cm_safe_malloc(num_samples * sizeof(float *), __FILE__, __LINE__);
18-
float **y_train = (float **)cm_safe_malloc(num_samples * sizeof(float *), __FILE__, __LINE__);
19-
20-
for (int i = 0; i < num_samples; i++)
21-
{
22-
X_train[i] = (float *)cm_safe_malloc(2 * sizeof(float), __FILE__, __LINE__);
23-
y_train[i] = (float *)cm_safe_malloc(1 * sizeof(float), __FILE__, __LINE__);
24-
}
14+
float X_data[4][2] = {
15+
{0.0f, 0.0f},
16+
{0.0f, 1.0f},
17+
{1.0f, 0.0f},
18+
{1.0f, 1.0f}};
2519

26-
X_train[0][0] = 0.0f;
27-
X_train[0][1] = 0.0f;
28-
y_train[0][0] = 0.0f;
29-
X_train[1][0] = 0.0f;
30-
X_train[1][1] = 1.0f;
31-
y_train[1][0] = 1.0f;
20+
float y_data[4][1] = {
21+
{0.0f},
22+
{1.0f},
23+
{1.0f},
24+
{1.0f}};
3225

33-
X_train[2][0] = 1.0f;
34-
X_train[2][1] = 0.0f;
35-
y_train[2][0] = 1.0f;
36-
37-
X_train[3][0] = 1.0f;
38-
X_train[3][1] = 1.0f;
39-
y_train[3][0] = 1.0f;
26+
Dataset *dataset = dataset_create();
27+
dataset_load_arrays(dataset, (float *)X_data, (float *)y_data, 4, 2, 1);
4028

4129
summary(network);
42-
train_network(network, X_train, y_train, num_samples, 2, 1, 1, 300);
43-
44-
MetricType metrics[] = {METRIC_R2_SCORE};
4530

46-
int num_metrics = sizeof(metrics) / sizeof(metrics[0]);
47-
float results[num_metrics];
48-
49-
test_network(network, X_train, y_train, num_samples, 2, 1, (int *)metrics, num_metrics, results);
50-
printf("R2 Score: %.2f\n", results[0]);
51-
52-
for (int i = 0; i < num_samples; i++)
53-
{
54-
float prediction = 0.0f;
55-
forward_pass(network, X_train[i], &prediction, 2, 1, 0);
56-
printf("Input: [%.0f, %.0f], Expected: %.0f, Predicted: %.4f\n",
57-
X_train[i][0], X_train[i][1], y_train[i][0], prediction);
58-
}
31+
train_network(network, dataset, 30);
32+
test_network(network, dataset->X, dataset->y, dataset->num_samples, NULL);
5933

34+
dataset_free(dataset);
6035
free_neural_network(network);
6136

62-
for (int i = 0; i < num_samples; i++)
63-
{
64-
cm_safe_free((void **)&X_train[i]);
65-
cm_safe_free((void **)&y_train[i]);
66-
}
67-
cm_safe_free((void **)&X_train);
68-
cm_safe_free((void **)&y_train);
69-
7037
return 0;
7138
}

0 commit comments

Comments
 (0)