Skip to content

Commit daab392

Browse files
committed
update mnist sample for accuracy
1 parent a974f47 commit daab392

File tree

1 file changed: +6 additions, -4 deletions

mnist.c

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@

 static int threads = -1;
 static int batch_size = 32;
-static int epoch_count = 5;
+static int epoch_count = 15;
 static int export_onnx = 0;

 //----------------------------------
@@ -214,10 +214,10 @@ int main(int argc, char *argv[])
 // make a new network
 PNetwork pnet = ann_make_network(OPT_ADAM, LOSS_CATEGORICAL_CROSS_ENTROPY);

-// define our network
+// define our network - deeper architecture with ReLU for better accuracy
 ann_add_layer(pnet, 784, LAYER_INPUT, ACTIVATION_NULL);
-ann_add_layer(pnet, 32, LAYER_HIDDEN, ACTIVATION_SIGMOID);
-// ann_add_layer(pnet, 128, LAYER_HIDDEN, ACTIVATION_RELU);
+ann_add_layer(pnet, 128, LAYER_HIDDEN, ACTIVATION_RELU);
+ann_add_layer(pnet, 64, LAYER_HIDDEN, ACTIVATION_RELU);
 ann_add_layer(pnet, 10, LAYER_OUTPUT, ACTIVATION_SOFTMAX);

 real *data = NULL, *test_data = NULL;
@@ -266,6 +266,8 @@ int main(int argc, char *argv[])
 ann_set_epoch_limit(pnet, epoch_count);
 ann_set_convergence(pnet, (real)0.1);
 ann_set_batch_size(pnet, batch_size);
+ann_set_dropout(pnet, 0.2f); // 20% dropout on hidden layers
+ann_set_gradient_clip(pnet, 5.0f); // Clip gradients for stability

 // Add exponential LR decay (5% reduction per epoch)
 static real lr_decay = 0.95f;

0 commit comments

Comments (0)