Commit 51b1312

Merge pull request #31 from Advaitgaur004/gradfn_softmax_fix
[Fix] : Correct Softmax Gradient
2 parents: 2a5ecd2 + cff22a3

File tree

4 files changed (+69 / -47 lines changed)


include/cten.h

Lines changed: 2 additions & 1 deletion
@@ -35,6 +35,7 @@ typedef struct GradNode {
     struct Tensor inputs[4];
     int n_inputs;
     const char* name;
+    int params[4];
 } GradNode;
 
 typedef struct {

@@ -111,7 +112,7 @@ Tensor nn_sigmoid(Tensor input);
 Tensor nn_tanh(Tensor input);
 Tensor nn_elu(Tensor self, float alpha);
 Tensor nn_selu(Tensor self);
-Tensor nn_softmax(Tensor input);
+Tensor nn_softmax(Tensor input, int dim);
 Tensor Glorot_init(TensorShape shape, bool requires_grad);
 Tensor nn_crossentropy(Tensor y_true, Tensor y_pred);
 Tensor nn_softmax_crossentropy(Tensor y_true, Tensor logits);
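
The header change makes the reduction axis explicit at every call site. A minimal usage sketch of the new signature (not part of this commit; the [2, 3] shape and the fill loop are hypothetical, but Tensor_new and the data->flex access pattern match the code elsewhere in this diff):

    /* Softmax over the class axis of a [batch, classes] tensor. */
    Tensor logits = Tensor_new((TensorShape){2, 3}, false);
    for(int i = 0; i < logits.data->numel; i++) logits.data->flex[i] = (float)i;

    Tensor probs = nn_softmax(logits, 1);  /* dim = 1: each row of probs now sums to 1 */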

src/basic.c

Lines changed: 3 additions & 1 deletion
@@ -175,7 +175,9 @@ void Tensor_backward(Tensor self, Tensor grad) {
 
         // Step 2: Apply the chain rule (upstream_grad * local_grad)
         Tensor combined_grad;
-        if(strcmp(self.node->name, "Matmul") == 0) {
+        if (strcmp(self.node->name, "Softmax") == 0) {
+            combined_grad = input_grad;
+        } else if(strcmp(self.node->name, "Matmul") == 0) {
             if (i == 0) {
                 combined_grad = Tensor_matmul(grad, input_grad);
             } else {
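
This special case exists because GradFn_softmax (see src/nn.c below) no longer returns a local derivative to be multiplied by the upstream gradient: it folds the upstream gradient into a Jacobian-vector product and returns dL/dz directly, so Tensor_backward must pass it through unchanged. With s = softmax(z) and g_i = dL/ds_i, the formula it implements follows from the softmax Jacobian:

    \frac{\partial s_i}{\partial z_j} = s_i(\delta_{ij} - s_j)
    \qquad\Longrightarrow\qquad
    \frac{\partial L}{\partial z_j} = \sum_i g_i\, s_i(\delta_{ij} - s_j)
                                    = s_j\Big(g_j - \sum_i g_i\, s_i\Big)

The parenthesised sum is exactly the dot_product computed in Step 1 of the new GradFn_softmax.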

src/nn.c

Lines changed: 64 additions & 43 deletions
@@ -302,57 +302,76 @@ static Tensor GradFn_softmax(Tensor self, int i) {
     Tensor input = self.node->inputs[i];
     Tensor grad = Tensor_new(input.shape, false);
 
-    int dim = TensorShape_dim(self.shape);
-    int batch_size = self.shape[0];
-    int num_classes = self.shape[1];
-    for(int b = 0; b < batch_size; b++){
-        for(int i = 0; i < num_classes; i++) {
-            for(int j = 0; j < num_classes; j++) {
-                float softmax_i = self.data->flex[b * num_classes + i];
-                float softmax_j = self.data->flex[b * num_classes + j];
-                float value;
-                if(i == j){
-                    value = softmax_i * (1.0f - softmax_i);
-                }
-                else{
-                    value = -softmax_i * softmax_j;
-                }
-
-                if(i == j){
-                    grad.data->flex[b * num_classes + i] = value;
-                }
+    int dim = self.node->params[0];
+    int input_ndim = TensorShape_dim(input.shape);
+
+    int dim_size = self.shape[dim];
+    int outer_size = 1;
+    for(int j = 0; j < dim; j++) {
+        outer_size *= self.shape[j];
+    }
+    int inner_size = 1;
+    for(int j = dim + 1; j < input_ndim; j++) {
+        inner_size *= self.shape[j];
+    }
+
+    float* s_data = self.data->flex;  // Softmax output data (s)
+    float* upstream_grad_data = self.node->grad.data->flex;  // Upstream grad (dL/ds)
+    float* input_grad_data = grad.data->flex;  // Resulting grad (dL/dz)
+    for (int outer = 0; outer < outer_size; outer++) {
+        for (int inner = 0; inner < inner_size; inner++) {
+            int slice_offset = outer * dim_size * inner_size + inner;
+            // Step 1. Calculate the dot product for the current slice: sum_k(dL/ds_k * s_k)
+            float dot_product = 0.0f;
+            for (int k = 0; k < dim_size; k++) {
+                int index = slice_offset + k * inner_size;
+                dot_product += upstream_grad_data[index] * s_data[index];
+            }
+
+            // Step 2. Calculate the final gradient using the formula: dL/dz_j = s_j * (dL/ds_j - dot_product)
+            for (int k = 0; k < dim_size; k++) {
+                int index = slice_offset + k * inner_size;
+                input_grad_data[index] = s_data[index] * (upstream_grad_data[index] - dot_product);
             }
         }
     }
     return grad;
 }
 
-Tensor nn_softmax(Tensor self) {
+Tensor nn_softmax(Tensor self, int dim) {
     bool requires_grad = !cten_is_eval() && self.node != NULL;
     Tensor res = Tensor_new(self.shape, requires_grad);
     int self_dim = TensorShape_dim(self.shape);
-    assert(self_dim > 0);
-    int last_dim_size = self.shape[self_dim - 1];
-    int outer_size = self.data->numel / last_dim_size;
-
+    assert(dim >= 0 && dim < self_dim);
+    int dim_size = self.shape[dim];
+    int outer_size = 1;
+    for(int i = 0; i < dim; i++) {
+        outer_size *= self.shape[i];
+    }
+    int inner_size = 1;
+    for(int i = dim + 1; i < self_dim; i++) {
+        inner_size *= self.shape[i];
+    }
+
     for(int outer = 0; outer < outer_size; outer++) {
-        float max_val = -INFINITY;
-        float sum = 0;
-
-        for(int d = 0; d < last_dim_size; d++) {
-            int index = outer * last_dim_size + d;
-            max_val = fmaxf(max_val, self.data->flex[index]);
-        }
-
-        for(int d = 0; d < last_dim_size; d++) {
-            int index = outer * last_dim_size + d;
-            res.data->flex[index] = expf(self.data->flex[index] - max_val);
-            sum += res.data->flex[index];
-        }
-
-        for(int d = 0; d < last_dim_size; d++) {
-            int index = outer * last_dim_size + d;
-            res.data->flex[index] /= sum;
+        for(int inner = 0; inner < inner_size; inner++) {
+            int slice_offset = outer * dim_size * inner_size + inner;
+            float max_val = -INFINITY;
+            for(int k = 0; k < dim_size; k++) {
+                int index = slice_offset + k * inner_size;
+                max_val = fmaxf(max_val, self.data->flex[index]);
+            }
+            float sum = 0.0f;
+            for(int k = 0; k < dim_size; k++) {
+                int index = slice_offset + k * inner_size;
+                float val = expf(self.data->flex[index] - max_val);
+                res.data->flex[index] = val;
+                sum += val;
+            }
+            for(int k = 0; k < dim_size; k++) {
+                int index = slice_offset + k * inner_size;
+                res.data->flex[index] /= sum;
+            }
         }
     }
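
Both the forward and backward loops rely on the same flattened layout: for a contiguous tensor, the element at position (outer, k, inner) lives at index (outer * dim_size + k) * inner_size + inner, which is exactly slice_offset + k * inner_size. A standalone sketch of the same traversal in plain C (no cten types; the (2, 3, 4) shape and fill values are arbitrary), softmaxing along dim = 1:

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        enum { OUTER = 2, DIM = 3, INNER = 4 };  /* shape (2, 3, 4), softmax along axis 1 */
        float x[OUTER * DIM * INNER], y[OUTER * DIM * INNER];
        for(int i = 0; i < OUTER * DIM * INNER; i++) x[i] = (float)(i % 5);

        for(int outer = 0; outer < OUTER; outer++) {
            for(int inner = 0; inner < INNER; inner++) {
                int slice_offset = outer * DIM * INNER + inner;
                float max_val = -INFINITY, sum = 0.0f;
                /* max-subtraction for numerical stability, as in nn_softmax */
                for(int k = 0; k < DIM; k++) max_val = fmaxf(max_val, x[slice_offset + k * INNER]);
                for(int k = 0; k < DIM; k++) {
                    int idx = slice_offset + k * INNER;
                    y[idx] = expf(x[idx] - max_val);
                    sum += y[idx];
                }
                for(int k = 0; k < DIM; k++) y[slice_offset + k * INNER] /= sum;
            }
        }
        /* each (outer, inner) slice now sums to 1 across the DIM axis */
        printf("%f\n", y[0] + y[INNER] + y[2 * INNER]);
        return 0;
    }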

@@ -361,6 +380,7 @@ Tensor nn_softmax(Tensor self) {
         res.node->inputs[0] = self;
         res.node->n_inputs = 1;
         res.node->name = "Softmax";
+        res.node->params[0] = dim;
     }
     return res;
 }

@@ -482,8 +502,9 @@ static Tensor GradFn_softmax_crossentropy(Tensor self, int i) {
 Tensor nn_softmax_crossentropy(Tensor y_true, Tensor logits) {
     bool requires_grad = !cten_is_eval() && logits.node != NULL;
     //disable gradient computation
-    cten_begin_eval();
-    Tensor y_pred = nn_softmax(logits);
+    cten_begin_eval();
+    int last_dim_logits = TensorShape_dim(logits.shape) - 1;
+    Tensor y_pred = nn_softmax(logits, last_dim_logits);
     Tensor loss = nn_crossentropy(y_true, y_pred);
     cten_end_eval();
     Tensor res = Tensor_zeros((TensorShape){1}, requires_grad);
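
A quick way to validate the corrected backward pass is a finite-difference check of dL/dz_j = s_j * (dL/ds_j - sum_k dL/ds_k * s_k) on a single slice. A self-contained sketch in plain C (the logits and upstream-gradient values are made up; this is not part of the repository's tests):

    #include <math.h>
    #include <stdio.h>

    #define N 3

    static void softmax(const float* z, float* s) {
        float m = -INFINITY, sum = 0.0f;
        for(int k = 0; k < N; k++) m = fmaxf(m, z[k]);
        for(int k = 0; k < N; k++) { s[k] = expf(z[k] - m); sum += s[k]; }
        for(int k = 0; k < N; k++) s[k] /= sum;
    }

    int main(void) {
        float z[N] = {0.2f, -1.0f, 0.5f};   /* logits */
        float g[N] = {0.3f, 0.7f, -0.1f};   /* upstream gradient dL/ds */
        float s[N];
        softmax(z, s);

        float dot = 0.0f;                    /* sum_k dL/ds_k * s_k */
        for(int k = 0; k < N; k++) dot += g[k] * s[k];

        for(int j = 0; j < N; j++) {
            float analytic = s[j] * (g[j] - dot);

            /* central difference of L = sum_k g_k * s_k(z) with respect to z_j */
            float eps = 1e-3f, zp[N], sp[N], lp = 0.0f, lm = 0.0f;
            for(int k = 0; k < N; k++) zp[k] = z[k];
            zp[j] = z[j] + eps; softmax(zp, sp);
            for(int k = 0; k < N; k++) lp += g[k] * sp[k];
            zp[j] = z[j] - eps; softmax(zp, sp);
            for(int k = 0; k < N; k++) lm += g[k] * sp[k];
            float numeric = (lp - lm) / (2.0f * eps);

            printf("j=%d analytic=% .6f numeric=% .6f\n", j, analytic, numeric);
        }
        return 0;
    }

The two columns should agree to a few decimal places, which is the behaviour the new GradFn_softmax reproduces through Tensor_backward.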

src/operator.c

Lines changed: 0 additions & 2 deletions
@@ -82,7 +82,6 @@ Tensor Tensor_mul(Tensor self, Tensor other) {
     return res;
 }
 
-
 Tensor Tensor_mulf(Tensor self, float other) {
     Tensor tmp = Tensor_new(self.shape, false);
     for(int i = 0; i < tmp.data->numel; i++) {

@@ -283,7 +282,6 @@ static Tensor GradFn_sub(Tensor self, int i) {
     return res;
 }
 
-
 static Tensor GradFn_div(Tensor self, int i) {
     Tensor res = Tensor_new(self.shape, false);
     Tensor x = self.node->inputs[0];
