@@ -278,25 +278,29 @@ static void record_history(PNetwork pnet, real loss, real learning_rate);
278278//------------------------------
279279static void softmax (PNetwork pnet )
280280{
281- real one_over_sum = 0.0 ;
282281 real sum = 0.0 ;
283282
284283 // find the sum of the output node values, excluding the bias node
285284 int output_layer = pnet -> layer_count - 1 ;
286285 int node_count = pnet -> layers [output_layer ].node_count ;
287-
288- //tensor_exp(pnet->layers[output_layer].t_values);
289- //one_over_sum = (real)1.0 / (tensor_sum(pnet->layers[output_layer].t_values) - 1.0);
290- //tensor_mul_scalar(pnet->layers[output_layer].t_values, one_over_sum);
286+ real * values = pnet -> layers [output_layer ].t_values -> values ;
287+
288+ // Find max for numerical stability
289+ real max_val = values [0 ];
290+ for (int node = 1 ; node < node_count ; node ++ )
291+ {
292+ if (values [node ] > max_val )
293+ max_val = values [node ];
294+ }
291295
292296 for (int node = 0 ; node < node_count ; node ++ )
293297 {
294- sum += (real )exp (pnet -> layers [ output_layer ]. t_values -> values [node ]);
298+ sum += (real )exp (values [node ] - max_val );
295299 }
296300
297301 for (int node = 0 ; node < node_count ; node ++ )
298302 {
299- pnet -> layers [ output_layer ]. t_values -> values [node ] = (real )(exp (pnet -> layers [ output_layer ]. t_values -> values [node ]) / sum );
303+ values [node ] = (real )(exp (values [node ] - max_val ) / sum );
300304 }
301305}
302306
@@ -543,7 +547,7 @@ static real compute_cross_entropy(PNetwork pnet, PTensor outputs)
543547
544548 for (int i = 0 ; i < poutput_layer -> node_count ; i ++ )
545549 {
546- xe += (real )(outputs -> values [i ] * log (poutput_layer -> t_values -> values [i ]));
550+ xe += (real )(outputs -> values [i ] * log (fmax ( poutput_layer -> t_values -> values [i ], ( real ) 1e-7 ) ));
547551 }
548552
549553 return - xe ;
@@ -598,10 +602,18 @@ static void softmax_batched(PNetwork pnet, int batch_size)
598602 real sum = (real )0.0 ;
599603 int row_offset = b * node_count ;
600604
601- // Compute exp and sum for this sample
605+ // Find max for numerical stability
606+ real max_val = batch -> values [row_offset ];
607+ for (int n = 1 ; n < node_count ; n ++ )
608+ {
609+ if (batch -> values [row_offset + n ] > max_val )
610+ max_val = batch -> values [row_offset + n ];
611+ }
612+
613+ // Compute exp(x - max) and sum for this sample
602614 for (int n = 0 ; n < node_count ; n ++ )
603615 {
604- batch -> values [row_offset + n ] = (real )exp (batch -> values [row_offset + n ]);
616+ batch -> values [row_offset + n ] = (real )exp (batch -> values [row_offset + n ] - max_val );
605617 sum += batch -> values [row_offset + n ];
606618 }
607619
@@ -3289,7 +3301,11 @@ static void write_json_float_array(FILE *fptr, const real *values, int count)
32893301 for (int i = 0 ; i < count ; i ++ )
32903302 {
32913303 if (i > 0 ) fprintf (fptr , ", " );
3292- fprintf (fptr , "%.8g" , (double )values [i ]);
3304+ double v = (double )values [i ];
3305+ if (isnan (v ) || isinf (v ))
3306+ fprintf (fptr , "0.0" );
3307+ else
3308+ fprintf (fptr , "%.8g" , v );
32933309 }
32943310 fprintf (fptr , "]" );
32953311}
0 commit comments