
Commit 2a5ecd2

Merge pull request #30 from Advaitgaur004/loss-function
[Feat]: Add MSE, MAE, and Huber Loss Functions
2 parents 02dc6d0 + 0c117a5 commit 2a5ecd2

File tree

include/cten.h
src/nn.c

2 files changed: +141 −2 lines changed


include/cten.h

Lines changed: 3 additions & 0 deletions
@@ -115,6 +115,9 @@ Tensor nn_softmax(Tensor input);
 Tensor Glorot_init(TensorShape shape, bool requires_grad);
 Tensor nn_crossentropy(Tensor y_true, Tensor y_pred);
 Tensor nn_softmax_crossentropy(Tensor y_true, Tensor logits);
+Tensor nn_mse_loss(Tensor y_true, Tensor y_pred);
+Tensor nn_mae_loss(Tensor y_true, Tensor y_pred);
+Tensor nn_huber_loss(Tensor y_true, Tensor y_pred, float delta);
 
 /* Memory Management */
 typedef int64_t PoolId;
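
To make the new API concrete for reviewers, here is a minimal forward-only sketch (not part of the commit). It assumes tensors can be created with Tensor_zeros and filled through the data->flex accessor that src/nn.c itself uses, it omits whatever memory-pool setup cten requires before allocating, and the values are made up.

#include <stdio.h>
#include "cten.h"

int main(void) {
    // Hypothetical targets and predictions (4 elements each).
    Tensor y_true = Tensor_zeros((TensorShape){4}, false);
    Tensor y_pred = Tensor_zeros((TensorShape){4}, false);
    float t[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    float p[4] = {1.5f, 1.5f, 2.0f, 6.0f};
    for (int i = 0; i < 4; i++) {
        y_true.data->flex[i] = t[i];
        y_pred.data->flex[i] = p[i];
    }

    // Per-element errors (y_pred - y_true) are {0.5, -0.5, -1.0, 2.0}.
    Tensor mse = nn_mse_loss(y_true, y_pred);           // (0.25+0.25+1+4)/4 = 1.375
    Tensor mae = nn_mae_loss(y_true, y_pred);           // (0.5+0.5+1+2)/4 = 1.0
    Tensor huber = nn_huber_loss(y_true, y_pred, 1.0f); // (0.125+0.125+0.5+1.5)/4 = 0.5625

    printf("MSE %f  MAE %f  Huber %f\n",
           mse.data->flex[0], mae.data->flex[0], huber.data->flex[0]);
    return 0;
}

Since y_pred here is built without a gradient node, requires_grad stays false inside each loss function and no graph is recorded.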

src/nn.c

Lines changed: 138 additions & 2 deletions
@@ -9,14 +9,14 @@
 #include <stdio.h>
 
 static float elu_alpha_value = 1.0f;
+static float huber_delta_value = 1.0f;
 
 Tensor nn_linear(Tensor input, Tensor weight, Tensor bias) {
     Tensor tmp = Tensor_matmul(input, weight);
     tmp = Tensor_add(tmp, bias);
     return tmp;
 }
 
-/* nn.relu */
 static Tensor GradFn_relu(Tensor self, int i) {
     Tensor input = self.node->inputs[i];
     Tensor res = Tensor_new(input.shape, false);
@@ -365,7 +365,6 @@ Tensor nn_softmax(Tensor self) {
     return res;
 }
 
-/* nn.cross_entropy */
 static Tensor GradFn_crossentropy(Tensor self, int i) {
     if (i == 1) { // Gradient w.r.t. y_pred
         Tensor y_true = self.node->inputs[0];
@@ -498,5 +497,142 @@ Tensor nn_softmax_crossentropy(Tensor y_true, Tensor logits) {
         res.node->name = "SoftmaxCrossEntropy";
     }
 
+    return res;
+}
+
+static Tensor GradFn_mse_loss(Tensor self, int i) {
+    if (i == 1) { // Gradient w.r.t y_pred
+        Tensor y_true = self.node->inputs[0];
+        Tensor y_pred = self.node->inputs[1];
+        int n = y_pred.data->numel;
+
+        Tensor grad = Tensor_new(y_pred.shape, false);
+        for (int j = 0; j < n; j++) {
+            grad.data->flex[j] = 2.0f * (y_pred.data->flex[j] - y_true.data->flex[j]) / n;
+        }
+        return grad;
+    }
+    return Tensor_zeros((TensorShape){1}, false);
+}
+
+Tensor nn_mse_loss(Tensor y_true, Tensor y_pred) {
+    bool requires_grad = !cten_is_eval() && y_pred.node != NULL;
+
+    cten_begin_eval();
+    Tensor error = Tensor_sub(y_pred, y_true);
+    Tensor squared_error = Tensor_square(error);
+    Tensor loss = Tensor_mean(squared_error);
+    cten_end_eval();
+
+    Tensor res = Tensor_new((TensorShape){1}, requires_grad);
+    res.data->flex[0] = loss.data->flex[0];
+
+    if (requires_grad) {
+        res.node->grad_fn = GradFn_mse_loss;
+        res.node->inputs[0] = y_true;
+        res.node->inputs[1] = y_pred;
+        res.node->n_inputs = 2;
+        res.node->name = "MSELoss";
+    }
+    return res;
+}
+
+static Tensor GradFn_mae_loss(Tensor self, int i) {
+    if (i == 1) { // Gradient w.r.t y_pred
+        Tensor y_true = self.node->inputs[0];
+        Tensor y_pred = self.node->inputs[1];
+        int n = y_pred.data->numel;
+
+        Tensor grad = Tensor_new(y_pred.shape, false);
+        for (int j = 0; j < n; j++) {
+            float error = y_pred.data->flex[j] - y_true.data->flex[j];
+            if (error > 0) {
+                grad.data->flex[j] = 1.0f / n;
+            } else if (error < 0) {
+                grad.data->flex[j] = -1.0f / n;
+            } else {
+                grad.data->flex[j] = 0.0f;
+            }
+        }
+        return grad;
+    }
+    return Tensor_zeros((TensorShape){1}, false);
+}
+
+Tensor nn_mae_loss(Tensor y_true, Tensor y_pred) {
+    bool requires_grad = !cten_is_eval() && y_pred.node != NULL;
+
+    cten_begin_eval();
+    Tensor error = Tensor_sub(y_pred, y_true);
+    Tensor abs_error = Tensor_abs(error);
+    Tensor loss = Tensor_mean(abs_error);
+    cten_end_eval();
+
+    Tensor res = Tensor_new((TensorShape){1}, requires_grad);
+    res.data->flex[0] = loss.data->flex[0];
+
+    if (requires_grad) {
+        res.node->grad_fn = GradFn_mae_loss;
+        res.node->inputs[0] = y_true;
+        res.node->inputs[1] = y_pred;
+        res.node->n_inputs = 2;
+        res.node->name = "MAELoss";
+    }
+    return res;
+}
+
+static Tensor GradFn_huber_loss(Tensor self, int i) {
+    if (i == 1) { // Gradient w.r.t y_pred
+        Tensor y_true = self.node->inputs[0];
+        Tensor y_pred = self.node->inputs[1];
+        float delta = huber_delta_value;
+        int n = y_pred.data->numel;
+
+        Tensor grad = Tensor_new(y_pred.shape, false);
+        // Gradient of Huber loss is (error / n) for small errors,
+        // and (delta * sign(error) / n) for large errors.
+        for (int j = 0; j < n; j++) {
+            float error = y_pred.data->flex[j] - y_true.data->flex[j];
+            if (fabsf(error) <= delta) {
+                grad.data->flex[j] = error / n;
+            } else {
+                if (error > 0) {
+                    grad.data->flex[j] = delta / n;
+                } else {
+                    grad.data->flex[j] = -delta / n;
+                }
+            }
+        }
+        return grad;
+    }
+    return Tensor_zeros((TensorShape){1}, false);
+}
+
+Tensor nn_huber_loss(Tensor y_true, Tensor y_pred, float delta) {
+    huber_delta_value = delta; // Store delta for the backward pass
+    bool requires_grad = !cten_is_eval() && y_pred.node != NULL;
+
+    int n = y_pred.data->numel;
+    float total_loss = 0.0f;
+    for (int i = 0; i < n; i++) {
+        float error = y_pred.data->flex[i] - y_true.data->flex[i];
+        float abs_error = fabsf(error);
+        if (abs_error <= delta) {
+            total_loss += 0.5f * error * error; // MSE part
+        } else {
+            total_loss += delta * (abs_error - 0.5f * delta); // MAE part
+        }
+    }
+
+    Tensor res = Tensor_new((TensorShape){1}, requires_grad);
+    res.data->flex[0] = total_loss / n; // Mean Huber Loss
+
+    if (requires_grad) {
+        res.node->grad_fn = GradFn_huber_loss;
+        res.node->inputs[0] = y_true;
+        res.node->inputs[1] = y_pred;
+        res.node->n_inputs = 2;
+        res.node->name = "HuberLoss";
+    }
     return res;
 }
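
For reference, nn_mse_loss and GradFn_mse_loss implement the mean-reduced squared error and its per-element derivative; the 1/n from the mean reduction is what puts n in the denominator of the gradient:

L_{\mathrm{MSE}} = \frac{1}{n}\sum_{j=1}^{n}(\hat{y}_j - y_j)^2,
\qquad
\frac{\partial L_{\mathrm{MSE}}}{\partial \hat{y}_j} = \frac{2(\hat{y}_j - y_j)}{n}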
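
nn_mae_loss follows the same pattern with the absolute error; since |x| has no derivative at 0, GradFn_mae_loss returns 0 there, a standard subgradient choice:

L_{\mathrm{MAE}} = \frac{1}{n}\sum_{j=1}^{n}\lvert \hat{y}_j - y_j \rvert,
\qquad
\frac{\partial L_{\mathrm{MAE}}}{\partial \hat{y}_j} = \frac{\operatorname{sign}(\hat{y}_j - y_j)}{n}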
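
The Huber loss blends the two, quadratic within delta of the target and linear outside, which matches the two branches in the forward loop and in GradFn_huber_loss:

\ell_\delta(e) =
\begin{cases}
\tfrac{1}{2}e^2 & \lvert e \rvert \le \delta \\
\delta\left(\lvert e \rvert - \tfrac{1}{2}\delta\right) & \text{otherwise}
\end{cases}
\qquad
L_{\mathrm{Huber}} = \frac{1}{n}\sum_{j=1}^{n}\ell_\delta(\hat{y}_j - y_j)

One caveat worth flagging in review: delta reaches the backward pass through the file-scope static huber_delta_value set during the forward pass, so running a second forward pass with a different delta before calling backward on the first graph would overwrite it.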
