65 changes: 32 additions & 33 deletions src/operator.c
@@ -285,41 +285,40 @@ static Tensor GradFn_sub(Tensor self, int i) {


 static Tensor GradFn_div(Tensor self, int i) {
     // f(x, y) = x / y; f'(x) = 1/y; f'(y) = -x/y²
-    if (i == 0) {
-        // Gradient w.r.t. x: 1/y
-        Tensor y = self.node->inputs[1];
-        Tensor res = Tensor_new(y.shape, false);
+    Tensor res = Tensor_new(self.shape, false);
+    Tensor x = self.node->inputs[0];
+    Tensor y = self.node->inputs[1];
+
+    if (i == 0) { // Gradient w.r.t. x: 1/y
         for (int j = 0; j < res.data->numel; j++) {
-            res.data->flex[j] = 1.0f / y.data->flex[j];
+            res.data->flex[j] = 1.0f / y.data->flex[j % y.data->numel];
         }
-        return res;
-    } else {
-        // Gradient w.r.t. y: -x/y²
-        Tensor x = self.node->inputs[0];
-        Tensor y = self.node->inputs[1];
-        Tensor res = Tensor_new(y.shape, false);
+    } else { // Gradient w.r.t. y: -x/y²
         for (int j = 0; j < res.data->numel; j++) {
-            float y_val = y.data->flex[j];
-            res.data->flex[j] = -x.data->flex[j] / (y_val * y_val);
+            float x_val = x.data->flex[j % x.data->numel];
+            float y_val = y.data->flex[j % y.data->numel];
+            res.data->flex[j] = -x_val / (y_val * y_val);
         }
-        return res;
     }
+    return res;
 }
 
 Tensor Tensor_div(Tensor self, Tensor other) {
+    Tensor orig_self = self;
+    Tensor orig_other = other;
+
     if (!cten_elemwise_broadcast(&self, &other)) {
-        cten_assert_shape("Tensor_div() cannot broadcast", self.shape, other.shape);
+        cten_assert_shape("Tensor_div() cannot broadcast", orig_self.shape, orig_other.shape);
     }
-    bool requires_grad = !cten_is_eval() && (self.node != NULL || other.node != NULL);
+    bool requires_grad = !cten_is_eval() && (orig_self.node != NULL || orig_other.node != NULL);
     Tensor res = Tensor_new(self.shape, requires_grad);
     for (int i = 0; i < self.data->numel; i++) {
         res.data->flex[i] = self.data->flex[i] / other.data->flex[i];
     }
     if (requires_grad) {
         res.node->grad_fn = GradFn_div;
-        res.node->inputs[0] = self;
-        res.node->inputs[1] = other;
+        res.node->inputs[0] = orig_self;
+        res.node->inputs[1] = orig_other;
         res.node->n_inputs = 2;
         res.node->name = "Div";
     }
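A note on the modulo indexing above: res is now allocated with the broadcast output shape (self.shape), while the recorded inputs keep their original, pre-broadcast shapes, so every operand read goes through j % numel. A minimal standalone sketch of that access pattern (broadcast_read is a hypothetical helper for illustration, not part of cten; the wraparound read is exact for the scalar and same-shape cases exercised by the tests below):

    /* Read element j of a logically broadcast operand that physically
     * holds m elements: a scalar (m == 1) yields the same value for
     * every output index; a same-size operand (m == n) is read directly. */
    static inline float broadcast_read(const float* data, int m, int j) {
        return data[j % m];
    }

    /* Example: gradient of z = x / y w.r.t. x when y is a scalar —
     * res[j] = 1.0f / broadcast_read(y_data, 1, j) for every j. */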
@@ -379,29 +378,27 @@ Tensor Tensor_reciprocal(Tensor self) {
 }
 
 static Tensor GradFn_pow(Tensor self, int i) {
-    // f(x, y) = x^y; f'(x) = y*x^(y-1); f'(y) = x^y * ln(x)
+    // f(x, y) = x^y; ∂f/∂x = y*x^(y-1); ∂f/∂y = x^y * ln(x)
+    Tensor res = Tensor_new(self.shape, false);
     Tensor x = self.node->inputs[0];
     Tensor y = self.node->inputs[1];
 
     if (i == 0) {
         // Gradient w.r.t. x: y*x^(y-1)
-        Tensor res = Tensor_new(x.shape, false);
         for (int j = 0; j < res.data->numel; j++) {
-            float x_val = x.data->flex[j];
-            float y_val = y.data->flex[j];
+            float x_val = x.data->flex[j % x.data->numel];
+            float y_val = y.data->flex[j % y.data->numel];
             if (x_val == 0.0f && y_val > 1.0f) {
                 res.data->flex[j] = 0.0f;
             } else {
                 res.data->flex[j] = y_val * powf(x_val, y_val - 1.0f);
             }
         }
-        return res;
     } else {
         // Gradient w.r.t. y: x^y * ln(x)
-        Tensor res = Tensor_new(y.shape, false);
         for (int j = 0; j < res.data->numel; j++) {
-            float x_val = x.data->flex[j];
-            float y_val = y.data->flex[j];
+            float x_val = x.data->flex[j % x.data->numel];
+            float self_val = self.data->flex[j];
             if (x_val <= 0.0f) {
                 // Gradient of x^y w.r.t y is undefined or complex for x <= 0.
                 // Returning 0 for simplicity, but this might need specific handling depending on use case.
@@ -412,26 +409,28 @@ static Tensor GradFn_pow(Tensor self, int i) {
                 // A robust solution might involve checking domain or returning NaN.
                 res.data->flex[j] = 0.0f;
             } else {
-                res.data->flex[j] = powf(x_val, y_val) * logf(x_val);
+                res.data->flex[j] = self_val * logf(x_val);
             }
         }
-        return res;
     }
+    return res;
 }
 
 Tensor Tensor_pow(Tensor self, Tensor other) {
+    Tensor orig_self = self;
+    Tensor orig_other = other;
     if (!cten_elemwise_broadcast(&self, &other)) {
-        cten_assert_shape("Tensor_pow() cannot broadcast", self.shape, other.shape);
+        cten_assert_shape("Tensor_pow() cannot broadcast", orig_self.shape, orig_other.shape);
     }
-    bool requires_grad = !cten_is_eval() && (self.node != NULL || other.node != NULL);
+    bool requires_grad = !cten_is_eval() && (orig_self.node != NULL || orig_other.node != NULL);
     Tensor res = Tensor_new(self.shape, requires_grad);
     for (int i = 0; i < self.data->numel; i++) {
         res.data->flex[i] = powf(self.data->flex[i], other.data->flex[i]);
     }
     if (requires_grad) {
         res.node->grad_fn = GradFn_pow;
-        res.node->inputs[0] = self;
-        res.node->inputs[1] = other;
+        res.node->inputs[0] = orig_self;
+        res.node->inputs[1] = orig_other;
         res.node->n_inputs = 2;
         res.node->name = "Pow";
     }
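Two details in the changes above are worth calling out. First, the gradient w.r.t. y now reuses the forward output instead of recomputing it: self already holds x^y, and since

    d/dy x^y = d/dy e^(y·ln(x)) = e^(y·ln(x)) · ln(x) = x^y · ln(x),

self_val * logf(x_val) equals the old powf(x_val, y_val) * logf(x_val) while saving one powf call per element and staying numerically consistent with the forward pass. Second, capturing orig_self and orig_other before cten_elemwise_broadcast mutates the locals means the graph records each input at its original shape, so the backward pass hands back gradients that match the pre-broadcast operands.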
166 changes: 166 additions & 0 deletions tests/Backward/test_div_backward.c
@@ -0,0 +1,166 @@
#include "../../include/cten.h"
#include "../test_utils.h"
#include "../csv_reporter.h"
#include "../test_config.h"
#include <stdio.h>

void test_div_backward() {
    const char* op_name = "div_backward";
    PoolId pool_id = 0;
    cten_begin_malloc(pool_id);

    // Test Case 1: Simple element-wise vector division
    {
        const char* tc_name = "div_vectors_backward";
        TensorShape shape = {3};
        float x_data[] = {6.7548f, 3.4753f, -7.6282f};
        float y_data[] = {4.5687f, 2.6877f, -1.8746f};

        // z = x / y = [1.4785, 1.2930, 4.0692]
        // loss = sum(z) = 6.8407
        float exp_grad_x[] = {0.218881f, 0.372065f, -0.533447f};
        float exp_grad_y[] = {-0.323614f, -0.481095f, 2.170725f};

        Tensor x = create_test_tensor(shape, x_data, true);
        Tensor y = create_test_tensor(shape, y_data, true);

        Tensor z = Tensor_div(x, y);
        Tensor loss = Tensor_sum(z);

        Tensor grad_dummy = {0};
        Tensor_backward(loss, grad_dummy);

        Tensor expected_grad_x = create_test_tensor(shape, exp_grad_x, false);
        Tensor expected_grad_y = create_test_tensor(shape, exp_grad_y, false);

        compare_tensors(&x.node->grad, &expected_grad_x, op_name, tc_name, 1, TEST_FLOAT_TOLERANCE);
        compare_tensors(&y.node->grad, &expected_grad_y, op_name, tc_name, 2, TEST_FLOAT_TOLERANCE);
    }

    // Test Case 2: Broadcasting a vector by a scalar
    {
        const char* tc_name = "div_broadcast_vec_scalar_backward";
        TensorShape x_shape = {2};
        TensorShape y_shape = {1};
        float x_data[] = {1.2388f, -6.849f};
        float y_data[] = {-1.8818f};

        // z = x / y = [-0.6583, 3.6396]
        // loss = sum(z) = 2.9813
        float exp_grad_x[] = {-0.531406f, -0.531406f};
        float exp_grad_y[] = {1.584278f};

        Tensor x = create_test_tensor(x_shape, x_data, true);
        Tensor y = create_test_tensor(y_shape, y_data, true);

        Tensor z = Tensor_div(x, y);
        Tensor loss = Tensor_sum(z);

        Tensor grad_dummy = {0};
        Tensor_backward(loss, grad_dummy);

        Tensor expected_grad_x = create_test_tensor(x_shape, exp_grad_x, false);
        Tensor expected_grad_y = create_test_tensor(y_shape, exp_grad_y, false);

        compare_tensors(&x.node->grad, &expected_grad_x, op_name, tc_name, 1, TEST_FLOAT_TOLERANCE);
        compare_tensors(&y.node->grad, &expected_grad_y, op_name, tc_name, 2, TEST_FLOAT_TOLERANCE);
    }

    // Test Case 3: Broadcasting a scalar by a vector
    {
        const char* tc_name = "div_broadcast_scalar_vec_backward";
        TensorShape x_shape = {1};
        TensorShape y_shape = {3};
        float x_data[] = {8.2849f};
        float y_data[] = {-4.2233f, 2.361f, 4.8289f};

        // z = x / y = [-1.9617, 3.5091, 1.7157]
        // loss = sum(z) = 3.2631
        float exp_grad_x[] = {0.393854f};
        float exp_grad_y[] = {-0.464498f, -1.486262f, -0.355296f};

        Tensor x = create_test_tensor(x_shape, x_data, true);
        Tensor y = create_test_tensor(y_shape, y_data, true);

        Tensor z = Tensor_div(x, y);
        Tensor loss = Tensor_sum(z);

        Tensor grad_dummy = {0};
        Tensor_backward(loss, grad_dummy);

        Tensor expected_grad_x = create_test_tensor(x_shape, exp_grad_x, false);
        Tensor expected_grad_y = create_test_tensor(y_shape, exp_grad_y, false);

        compare_tensors(&x.node->grad, &expected_grad_x, op_name, tc_name, 1, TEST_FLOAT_TOLERANCE);
        compare_tensors(&y.node->grad, &expected_grad_y, op_name, tc_name, 2, TEST_FLOAT_TOLERANCE);
    }

    // Test Case 4: Matrix division with negative values
    {
        const char* tc_name = "div_matrices_neg_vals_backward";
        TensorShape shape = {2, 2};
        float x_data[] = {1.8347f, -8.6274f, -8.2642f, -5.8261f};
        float y_data[] = {-2.5141f, -4.3176f, -4.4468f, 3.8183f};

        // z = x / y = [-0.7298, 1.9982, 1.8585, -1.5258]
        // loss = sum(z) = 1.6011
        float exp_grad_x[] = {-0.397757f, -0.23161f, -0.224881f, 0.261897f};
        float exp_grad_y[] = {-0.290269f, 0.462802f, 0.417932f, 0.399611f};

        Tensor x = create_test_tensor(shape, x_data, true);
        Tensor y = create_test_tensor(shape, y_data, true);

        Tensor z = Tensor_div(x, y);
        Tensor loss = Tensor_sum(z);

        Tensor grad_dummy = {0};
        Tensor_backward(loss, grad_dummy);

        Tensor expected_grad_x = create_test_tensor(shape, exp_grad_x, false);
        Tensor expected_grad_y = create_test_tensor(shape, exp_grad_y, false);

        compare_tensors(&x.node->grad, &expected_grad_x, op_name, tc_name, 1, TEST_FLOAT_TOLERANCE);
        compare_tensors(&y.node->grad, &expected_grad_y, op_name, tc_name, 2, TEST_FLOAT_TOLERANCE);
    }

    // Test Case 5: Complex computation graph (z = (a/b) * c)
    {
        const char* tc_name = "div_complex_graph_backward";
        TensorShape shape = {1};
        float a_data[] = {3.0511f};
        float b_data[] = {1.3192f};
        float c_data[] = {1.404f};

        // Let d = a / b. Then z = d * c.
        // Forward: d = 3.0511/1.3192 = 2.3129. z = 2.3129 * 1.404 = 3.2472
        // Backward pass:
        // dz/dc = d = 2.312841
        float exp_grad_c[] = {2.312841f};

        // dz/d(d) = c = 1.404 (This is the upstream gradient for the div op)
        // dz/da = (dz/dd) * (dd/da) = c * (1/b) = 1.404 * (1/1.3192) = 1.064281
        float exp_grad_a[] = {1.064281f};
        // dz/db = (dz/dd) * (dd/db) = c * (-a/b²) = 1.404 * (-3.0511/(1.3192*1.3192)) = -2.461514
        float exp_grad_b[] = {-2.461514f};

        Tensor a = create_test_tensor(shape, a_data, true);
        Tensor b = create_test_tensor(shape, b_data, true);
        Tensor c = create_test_tensor(shape, c_data, true);

        Tensor d = Tensor_div(a, b);
        Tensor z = Tensor_mul(d, c);

        Tensor grad_dummy = {0};
        Tensor_backward(z, grad_dummy);

        Tensor expected_grad_a_tensor = create_test_tensor(shape, exp_grad_a, false);
        Tensor expected_grad_b_tensor = create_test_tensor(shape, exp_grad_b, false);
        Tensor expected_grad_c_tensor = create_test_tensor(shape, exp_grad_c, false);

        compare_tensors(&a.node->grad, &expected_grad_a_tensor, op_name, tc_name, 1, TEST_FLOAT_TOLERANCE);
        compare_tensors(&b.node->grad, &expected_grad_b_tensor, op_name, tc_name, 2, TEST_FLOAT_TOLERANCE);
        compare_tensors(&c.node->grad, &expected_grad_c_tensor, op_name, tc_name, 3, TEST_FLOAT_TOLERANCE);
    }

    cten_free(pool_id);
}
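As a hand check of the broadcast cases above, take Test Case 2 (div_broadcast_vec_scalar_backward): with loss = Σᵢ xᵢ/y and scalar y,

    ∂loss/∂y = Σᵢ -xᵢ/y² = -(1.2388 + (-6.849)) / 1.8818² = 5.6102 / 3.54117 ≈ 1.584278,

which matches exp_grad_y. Note that GradFn_div itself returns a gradient in the broadcast output shape; the summation over the broadcast dimension is presumably performed by the surrounding backward machinery before the result is accumulated into y's gradient.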