Skip to content

Commit 8e776a2

Browse files
committed
fix NaN/Inf issue with exp
1 parent 663fece commit 8e776a2

File tree

2 files changed

+51
-11
lines changed

2 files changed

+51
-11
lines changed

ann.c

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -278,25 +278,29 @@ static void record_history(PNetwork pnet, real loss, real learning_rate);
278278
//------------------------------
279279
static void softmax(PNetwork pnet)
280280
{
281-
real one_over_sum = 0.0;
282281
real sum = 0.0;
283282

284283
// find the sum of the output node values, excluding the bias node
285284
int output_layer = pnet->layer_count - 1;
286285
int node_count = pnet->layers[output_layer].node_count;
287-
288-
//tensor_exp(pnet->layers[output_layer].t_values);
289-
//one_over_sum = (real)1.0 / (tensor_sum(pnet->layers[output_layer].t_values) - 1.0);
290-
//tensor_mul_scalar(pnet->layers[output_layer].t_values, one_over_sum);
286+
real *values = pnet->layers[output_layer].t_values->values;
287+
288+
// Find max for numerical stability
289+
real max_val = values[0];
290+
for (int node = 1; node < node_count; node++)
291+
{
292+
if (values[node] > max_val)
293+
max_val = values[node];
294+
}
291295

292296
for (int node = 0; node < node_count; node++)
293297
{
294-
sum += (real)exp(pnet->layers[output_layer].t_values->values[node]);
298+
sum += (real)exp(values[node] - max_val);
295299
}
296300

297301
for (int node = 0; node < node_count; node++)
298302
{
299-
pnet->layers[output_layer].t_values->values[node] = (real)(exp(pnet->layers[output_layer].t_values->values[node]) / sum);
303+
values[node] = (real)(exp(values[node] - max_val) / sum);
300304
}
301305
}
302306

@@ -543,7 +547,7 @@ static real compute_cross_entropy(PNetwork pnet, PTensor outputs)
543547

544548
for (int i = 0; i < poutput_layer->node_count; i++)
545549
{
546-
xe += (real)(outputs->values[i] * log(poutput_layer->t_values->values[i]));
550+
xe += (real)(outputs->values[i] * log(fmax(poutput_layer->t_values->values[i], (real)1e-7)));
547551
}
548552

549553
return -xe;
@@ -598,10 +602,18 @@ static void softmax_batched(PNetwork pnet, int batch_size)
598602
real sum = (real)0.0;
599603
int row_offset = b * node_count;
600604

601-
// Compute exp and sum for this sample
605+
// Find max for numerical stability
606+
real max_val = batch->values[row_offset];
607+
for (int n = 1; n < node_count; n++)
608+
{
609+
if (batch->values[row_offset + n] > max_val)
610+
max_val = batch->values[row_offset + n];
611+
}
612+
613+
// Compute exp(x - max) and sum for this sample
602614
for (int n = 0; n < node_count; n++)
603615
{
604-
batch->values[row_offset + n] = (real)exp(batch->values[row_offset + n]);
616+
batch->values[row_offset + n] = (real)exp(batch->values[row_offset + n] - max_val);
605617
sum += batch->values[row_offset + n];
606618
}
607619

@@ -3289,7 +3301,11 @@ static void write_json_float_array(FILE *fptr, const real *values, int count)
32893301
for (int i = 0; i < count; i++)
32903302
{
32913303
if (i > 0) fprintf(fptr, ", ");
3292-
fprintf(fptr, "%.8g", (double)values[i]);
3304+
double v = (double)values[i];
3305+
if (isnan(v) || isinf(v))
3306+
fprintf(fptr, "0.0");
3307+
else
3308+
fprintf(fptr, "%.8g", v);
32933309
}
32943310
fprintf(fptr, "]");
32953311
}

test_onnx_export.c

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,30 @@ void test_main(int argc, char *argv[]) {
288288

289289
remove_file(import_file);
290290

291+
// ========================================================================
292+
// NaN/Inf HANDLING TESTS
293+
// ========================================================================
294+
SUITE("ONNX Export NaN/Inf Handling");
295+
COMMENT("Testing that NaN/Inf values are serialized as valid JSON...");
296+
297+
net = ann_make_network(OPT_SGD, LOSS_MSE);
298+
ann_add_layer(net, 2, LAYER_INPUT, ACTIVATION_NULL);
299+
ann_add_layer(net, 2, LAYER_OUTPUT, ACTIVATION_SIGMOID);
300+
301+
// Inject NaN and Inf into weights
302+
net->layers[0].t_weights->values[0] = NAN;
303+
net->layers[0].t_weights->values[1] = INFINITY;
304+
net->layers[0].t_bias->values[0] = -INFINITY;
305+
306+
result = ann_export_onnx(net, test_file);
307+
TESTEX("Export with NaN/Inf values returns ERR_OK", (result == ERR_OK));
308+
TESTEX("Output does not contain 'nan'", !file_contains(test_file, "nan"));
309+
TESTEX("Output does not contain 'inf'", !file_contains(test_file, "inf"));
310+
TESTEX("Output contains 0.0 replacement", file_contains(test_file, "0.0"));
311+
312+
remove_file(test_file);
313+
ann_free_network(net);
314+
291315
// Cleanup any remaining test files
292316
remove_file(test_file);
293317
remove_file(pikchr_file);

0 commit comments

Comments
 (0)