@@ -302,57 +302,76 @@ static Tensor GradFn_softmax(Tensor self, int i) {
     Tensor input = self.node->inputs[i];
     Tensor grad = Tensor_new(input.shape, false);
 
-    int dim = TensorShape_dim(self.shape);
-    int batch_size = self.shape[0];
-    int num_classes = self.shape[1];
-    for(int b = 0; b < batch_size; b++){
-        for(int i = 0; i < num_classes; i++) {
-            for(int j = 0; j < num_classes; j++) {
-                float softmax_i = self.data->flex[b * num_classes + i];
-                float softmax_j = self.data->flex[b * num_classes + j];
-                float value;
-                if(i == j){
-                    value = softmax_i * (1.0f - softmax_i);
-                }
-                else {
-                    value = -softmax_i * softmax_j;
-                }
-
-                if(i == j){
-                    grad.data->flex[b * num_classes + i] = value;
-                }
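+    // The softmax dim recorded by nn_softmax at forward time (res.node->params[0]).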
+    int dim = self.node->params[0];
+    int input_ndim = TensorShape_dim(input.shape);
+
+    int dim_size = self.shape[dim];
+    int outer_size = 1;
+    for(int j = 0; j < dim; j++) {
+        outer_size *= self.shape[j];
+    }
+    int inner_size = 1;
+    for(int j = dim + 1; j < input_ndim; j++) {
+        inner_size *= self.shape[j];
+    }
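+    // View the tensor as [outer_size, dim_size, inner_size]: each (outer, inner)
+    // pair selects one softmax slice along dim, its elements strided by inner_size.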
+
+    float* s_data = self.data->flex;                         // Softmax output data (s)
+    float* upstream_grad_data = self.node->grad.data->flex;  // Upstream grad (dL/ds)
+    float* input_grad_data = grad.data->flex;                // Resulting grad (dL/dz)
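+    // Since ds_i/dz_j = s_i * (delta_ij - s_j), the vector-Jacobian product collapses to
+    //   dL/dz_j = sum_i dL/ds_i * ds_i/dz_j = s_j * (dL/ds_j - sum_k dL/ds_k * s_k),
+    // so each slice needs a single dot product instead of the full Jacobian.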
+    for(int outer = 0; outer < outer_size; outer++) {
+        for(int inner = 0; inner < inner_size; inner++) {
+            int slice_offset = outer * dim_size * inner_size + inner;
+            // Step 1. Calculate the dot product for the current slice: sum_k(dL/ds_k * s_k)
+            float dot_product = 0.0f;
+            for(int k = 0; k < dim_size; k++) {
+                int index = slice_offset + k * inner_size;
+                dot_product += upstream_grad_data[index] * s_data[index];
+            }
+
+            // Step 2. Calculate the final gradient using the formula: dL/dz_j = s_j * (dL/ds_j - dot_product)
+            for(int k = 0; k < dim_size; k++) {
+                int index = slice_offset + k * inner_size;
+                input_grad_data[index] = s_data[index] * (upstream_grad_data[index] - dot_product);
             }
         }
     }
     return grad;
 }
 
-Tensor nn_softmax(Tensor self) {
+Tensor nn_softmax(Tensor self, int dim) {
     bool requires_grad = !cten_is_eval() && self.node != NULL;
     Tensor res = Tensor_new(self.shape, requires_grad);
     int self_dim = TensorShape_dim(self.shape);
-    assert(self_dim > 0);
-    int last_dim_size = self.shape[self_dim - 1];
-    int outer_size = self.data->numel / last_dim_size;
-
+    assert(dim >= 0 && dim < self_dim);
+    int dim_size = self.shape[dim];
+    int outer_size = 1;
+    for(int i = 0; i < dim; i++) {
+        outer_size *= self.shape[i];
+    }
+    int inner_size = 1;
+    for(int i = dim + 1; i < self_dim; i++) {
+        inner_size *= self.shape[i];
+    }
+
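+    // Same [outer_size, dim_size, inner_size] view as in GradFn_softmax: visit every
+    // (outer, inner) pair and normalize the slice along dim.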
     for(int outer = 0; outer < outer_size; outer++) {
-        float max_val = -INFINITY;
-        float sum = 0;
-
-        for(int d = 0; d < last_dim_size; d++) {
-            int index = outer * last_dim_size + d;
-            max_val = fmaxf(max_val, self.data->flex[index]);
-        }
-
-        for(int d = 0; d < last_dim_size; d++) {
-            int index = outer * last_dim_size + d;
-            res.data->flex[index] = expf(self.data->flex[index] - max_val);
-            sum += res.data->flex[index];
-        }
-
-        for(int d = 0; d < last_dim_size; d++) {
-            int index = outer * last_dim_size + d;
-            res.data->flex[index] /= sum;
+        for(int inner = 0; inner < inner_size; inner++) {
+            int slice_offset = outer * dim_size * inner_size + inner;
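+            // Subtract the slice max before exp: softmax is shift-invariant, and this
+            // keeps expf from overflowing on large logits.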
+            float max_val = -INFINITY;
+            for(int k = 0; k < dim_size; k++) {
+                int index = slice_offset + k * inner_size;
+                max_val = fmaxf(max_val, self.data->flex[index]);
+            }
+            float sum = 0.0f;
+            for(int k = 0; k < dim_size; k++) {
+                int index = slice_offset + k * inner_size;
+                float val = expf(self.data->flex[index] - max_val);
+                res.data->flex[index] = val;
+                sum += val;
+            }
+            for(int k = 0; k < dim_size; k++) {
+                int index = slice_offset + k * inner_size;
+                res.data->flex[index] /= sum;
+            }
         }
     }
 
@@ -361,6 +380,7 @@ Tensor nn_softmax(Tensor self) {
         res.node->inputs[0] = self;
         res.node->n_inputs = 1;
         res.node->name = "Softmax";
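+        // Stash the reduction dim on the node so GradFn_softmax can read it back in backward.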
+        res.node->params[0] = dim;
     }
     return res;
 }
@@ -482,8 +502,9 @@ static Tensor GradFn_softmax_crossentropy(Tensor self, int i) {
 Tensor nn_softmax_crossentropy(Tensor y_true, Tensor logits) {
     bool requires_grad = !cten_is_eval() && logits.node != NULL;
     //disable gradient computation
-    cten_begin_eval();
-    Tensor y_pred = nn_softmax(logits);
+    cten_begin_eval();
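+    // Preserve the previous behavior: softmax over the trailing (class) dim of the logits.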
+    int last_dim_logits = TensorShape_dim(logits.shape) - 1;
+    Tensor y_pred = nn_softmax(logits, last_dim_logits);
     Tensor loss = nn_crossentropy(y_true, y_pred);
     cten_end_eval();
     Tensor res = Tensor_zeros((TensorShape){1}, requires_grad);
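
A minimal call-site sketch of the updated API, assuming the cten constructors visible in this diff (Tensor_new with a compound-literal TensorShape); the 3-D shape and variable names are illustrative, not part of the commit:

    // Softmax over dim 1 of a hypothetical [batch, classes, length] tensor normalizes
    // each length-`classes` column. nn_softmax records dim in res.node->params[0],
    // and GradFn_softmax reads it back during the backward pass.
    Tensor logits = Tensor_new((TensorShape){8, 10, 32}, true);
    Tensor probs = nn_softmax(logits, 1);  // previously only the last dim was supported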