
Commit cd2a310

Merge pull request #24 from Advaitgaur004/backward-3
Fix: Stabilize Backpropagation for Sum and Mean
2 parents 451cc3f + b707ee5 commit cd2a310

File tree

9 files changed, +705 −72 lines changed


include/cten.h

Lines changed: 2 additions & 1 deletion
@@ -134,4 +134,5 @@ void cten_assert_dim(const char* title, int a, int b);
 bool cten_elemwise_broadcast(Tensor* a, Tensor* b);
 int load_iris_dataset(const float (**X)[4], const int** y);
 Tensor Tensor_reduce_dim(Tensor self, int dim, const char* operation);
-Tensor reduce_gradient_for_broadcasting(Tensor grad, TensorShape original_shape, TensorShape broadcasted_shape);
+Tensor reduce_gradient_for_broadcasting(Tensor grad, TensorShape original_shape, TensorShape broadcasted_shape);
+Tensor Tensor_unsqueeze(Tensor self, int dim);
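The Tensor_unsqueeze declaration is the new helper this PR introduces. As a hypothetical usage sketch (not code from this commit, but following the Tensor_new and TensorShape initializer style visible elsewhere in this diff), it inserts a size-1 axis at the given position so a reduced gradient can later be broadcast back against its input:

    TensorShape s = {2, 3};
    Tensor t = Tensor_new(s, false);     // shape {2, 3}
    Tensor u = Tensor_unsqueeze(t, 1);   // shape {2, 1, 3}; element count unchanged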

src/basic.c

Lines changed: 26 additions & 1 deletion
@@ -149,7 +149,32 @@ void Tensor_backward(Tensor self, Tensor grad) {
         // Step 1: Get the local gradient (the partial derivative). --> For z = f(x, y), this would be dz/dx or dz/dy.
         Tensor input_grad = self.node->grad_fn(self, i);
 
-        // Step 2: Apply the chain rule. --> The gradient flowing to the input is upstream_grad * local_grad.
+        // This is the gradient flowing from the output, which we need to propagate backwards.
+        Tensor grad = self.node->grad;
+        int input_ndim = TensorShape_dim(input_tensor.shape);
+        int grad_ndim = TensorShape_dim(grad.shape);
+
+        if ((strcmp(self.node->name, "Sum") == 0 || strcmp(self.node->name, "Mean") == 0) && input_ndim > grad_ndim) {
+            // Find the dimension that was reduced. We assume the non-reduced dimensions match in size.
+            int unsqueeze_dim = -1;
+            int grad_idx = 0;
+            for (int dim_idx = 0; dim_idx < input_ndim; ++dim_idx) {
+                if (grad_idx >= grad_ndim || input_tensor.shape[dim_idx] != grad.shape[grad_idx]) {
+                    // Yes, this is the dimension that was removed.
+                    unsqueeze_dim = dim_idx;
+                    break;
+                }
+                grad_idx++;
+            }
+
+            if (unsqueeze_dim != -1) {
+                grad = Tensor_unsqueeze(grad, unsqueeze_dim);
+            } else {
+                cten_assert(false, "Could not deduce unsqueeze dimension.");
+            }
+        }
+
+        // Step 2: Apply the chain rule (upstream_grad * local_grad)
         Tensor combined_grad;
         if(strcmp(self.node->name, "Matmul") == 0) {
             if (i == 0) {
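To make the new shape handling concrete: after a per-axis Sum or Mean, the upstream gradient has one fewer dimension than the input (summing a {2, 3} tensor over dim 0 leaves a {3} gradient), so the code above finds the missing axis and unsqueezes the gradient there before the chain-rule multiply. Below is a standalone sketch of that deduction logic in plain C, outside the cten API, under the same assumption stated in the diff comment (non-reduced dimensions keep their sizes):

    #include <stdio.h>

    /* Return the index of the axis removed by a reduction, or -1 if none. */
    static int find_reduced_dim(const int* in_shape, int in_ndim,
                                const int* out_shape, int out_ndim) {
        int out_idx = 0;
        for (int d = 0; d < in_ndim; ++d) {
            /* The first position where the shapes disagree (or where the
             * output shape runs out) is taken to be the reduced axis. */
            if (out_idx >= out_ndim || in_shape[d] != out_shape[out_idx]) return d;
            out_idx++;
        }
        return -1;
    }

    int main(void) {
        int input[] = {2, 3};  /* input to Sum/Mean */
        int grad[]  = {3};     /* upstream gradient after reducing dim 0 */
        /* Prints 0: unsqueezing the gradient at dim 0 gives {1, 3}, which
         * broadcasts against the {2, 3} input in the chain rule. */
        printf("reduced dim = %d\n", find_reduced_dim(input, 2, grad, 1));
        return 0;
    }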

src/operator.c

Lines changed: 33 additions & 4 deletions
@@ -105,11 +105,40 @@ void Tensor_argmax(Tensor self, int* out) {
 }
 
 Tensor GradFn_mean(Tensor self, int i) {
-    // f(x) = mean(x); f'(x) = 1 / x.numel()
-    Tensor res = Tensor_new(self.shape, false);
-    for(int i = 0; i < res.data->numel; i++) {
-        res.data->flex[i] = 1.0f / self.data->numel;
+    Tensor input_tensor = self.node->inputs[i];
+    int divisor;
+
+    if (TensorShape_numel(self.shape) == 1 && TensorShape_numel(input_tensor.shape) > 1) {
+        divisor = TensorShape_numel(input_tensor.shape);
+    } else {
+        int input_ndim = TensorShape_dim(input_tensor.shape);
+        int output_ndim = TensorShape_dim(self.shape);
+        if (input_ndim > output_ndim) {
+            int out_idx = 0;
+            int reduced_dim_size = 1;
+            for(int d=0; d < input_ndim; ++d) {
+                if(out_idx >= output_ndim || input_tensor.shape[d] != self.shape[out_idx]) {
+                    reduced_dim_size = input_tensor.shape[d];
+                    break;
+                }
+                out_idx++;
+            }
+            divisor = reduced_dim_size;
+        } else {
+            // scalar input
+            divisor = TensorShape_numel(input_tensor.shape);
+        }
     }
+
+    // gradient ==> SAME SHAPE as the ORIGINAL INPUT.
+    Tensor res = Tensor_new(input_tensor.shape, false);
+
+    // gradient value is 1 divided by the number of elements that were averaged.
+    float grad_val = 1.0f / divisor;
+
+    for(int j = 0; j < res.data->numel; j++) {
+        res.data->flex[j] = grad_val;
+    }
     return res;
 }
 
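The important change is that the divisor now depends on how the mean was taken (a full mean divides by the input's element count, a per-axis mean divides by the size of the removed axis) and that the gradient is allocated with the input's shape rather than the output's. A self-contained sketch of the same divisor rule in plain C, outside cten, with illustrative shapes:

    #include <stdio.h>

    static int numel(const int* shape, int ndim) {
        int n = 1;
        for (int i = 0; i < ndim; ++i) n *= shape[i];
        return n;
    }

    /* Divisor for the mean gradient: a full reduction divides by every element,
     * an axis reduction divides by the size of the axis that was removed. */
    static int mean_divisor(const int* in, int in_ndim, const int* out, int out_ndim) {
        if (numel(out, out_ndim) == 1 && numel(in, in_ndim) > 1)
            return numel(in, in_ndim);                 /* full mean */
        if (in_ndim > out_ndim) {
            int out_idx = 0;
            for (int d = 0; d < in_ndim; ++d) {        /* find the removed axis */
                if (out_idx >= out_ndim || in[d] != out[out_idx]) return in[d];
                out_idx++;
            }
        }
        return numel(in, in_ndim);                     /* scalar input */
    }

    int main(void) {
        int in[] = {2, 3};
        int full[] = {1};   /* mean over everything */
        int dim1[] = {2};   /* mean over dim 1 only */
        printf("full mean:  grad = 1/%d per element\n", mean_divisor(in, 2, full, 1)); /* 1/6 */
        printf("dim-1 mean: grad = 1/%d per element\n", mean_divisor(in, 2, dim1, 1)); /* 1/3 */
        return 0;
    }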

src/utils.c

Lines changed: 23 additions & 0 deletions
@@ -342,5 +342,28 @@ Tensor Tensor_reduce_dim(Tensor self, int dim, const char* operation) {
         }
     }
 
+    return res;
+}
+
+Tensor Tensor_unsqueeze(Tensor self, int dim) {
+    int old_ndim = TensorShape_dim(self.shape);
+    cten_assert(dim >= 0 && dim <= old_ndim, "Unsqueeze dim out of bounds");
+
+    TensorShape new_shape = {0};
+    int old_idx = 0;
+    // insert a '1' at the 'dim' position in the new shape.
+    for (int i = 0; i < old_ndim + 1 && i < 4; i++) {
+        if (i == dim) {
+            new_shape[i] = 1;
+        } else {
+            if(old_idx < 4) {
+                new_shape[i] = self.shape[old_idx++];
+            }
+        }
+    }
+
+    Tensor res = self;
+    memcpy(res.shape, new_shape, sizeof(TensorShape));
+
     return res;
 }
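One note on the approach (editorial, not from the commit): inserting a size-1 dimension changes neither the element count nor the row-major element order, so returning a copy of the Tensor struct that shares the original data buffer and only rewrites the shape metadata is sufficient. A small hypothetical sanity check, reusing calls already visible in this diff:

    TensorShape s = {4, 5};
    Tensor a = Tensor_new(s, false);
    Tensor b = Tensor_unsqueeze(a, 0);   // shape becomes {1, 4, 5}
    // Metadata-only change: the element count must be preserved.
    cten_assert(TensorShape_numel(a.shape) == TensorShape_numel(b.shape),
                "unsqueeze changed numel");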

tests/Backward/test_linear_backward.c

Lines changed: 44 additions & 45 deletions
@@ -163,66 +163,65 @@ void test_linear_backward() {
         }
     }
 
-    // TODO: Tensor_sum and Tensor_mean backward is in working progress
-    // // Test Case 4: Chained operations with linear
-    // {
-    //     const char* tc_name = "Chained_operations_with_linear";
-    //     // Sub-test 1: Linear followed by sum
-    //     {
-    //         TensorShape input_shape = {2, 3}; // batch_size=2, input_features=3
-    //         TensorShape weight_shape = {3, 2}; // input_features=3, output_features=2
-    //         TensorShape bias_shape = {1, 2}; // output_features=2
+    // Test Case 4: Chained operations with linear
+    {
+        const char* tc_name = "Chained_operations_with_linear";
+        // Sub-test 1: Linear followed by sum
+        {
+            TensorShape input_shape = {2, 3}; // batch_size=2, input_features=3
+            TensorShape weight_shape = {3, 2}; // input_features=3, output_features=2
+            TensorShape bias_shape = {1, 2}; // output_features=2
 
-    //         float input_data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
-    //         float weight_data[] = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};
-    //         float bias_data[] = {0.1f, 0.2f};
+            float input_data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+            float weight_data[] = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};
+            float bias_data[] = {0.1f, 0.2f};
 
-    //         // Expected gradients
-    //         float exp_grad_bias[] = {1.0f, 1.0f}; // For sum reduction
+            // Expected gradients
+            float exp_grad_bias[] = {2.0f, 2.0f}; // For sum reduction
 
-    //         Tensor input = create_test_tensor(input_shape, input_data, true);
-    //         Tensor weight = create_test_tensor(weight_shape, weight_data, true);
-    //         Tensor bias = create_test_tensor(bias_shape, bias_data, true);
+            Tensor input = create_test_tensor(input_shape, input_data, true);
+            Tensor weight = create_test_tensor(weight_shape, weight_data, true);
+            Tensor bias = create_test_tensor(bias_shape, bias_data, true);
 
-    //         Tensor output = nn_linear(input, weight, bias);
-    //         Tensor sum_output = Tensor_sum(output);
+            Tensor output = nn_linear(input, weight, bias);
+            Tensor sum_output = Tensor_sum(output);
 
-    //         Tensor_backward(sum_output, (Tensor){0});
+            Tensor_backward(sum_output, (Tensor){0});
 
-    //         Tensor expected_grad_bias = create_test_tensor(bias_shape, exp_grad_bias, false);
+            Tensor expected_grad_bias = create_test_tensor(bias_shape, exp_grad_bias, false);
 
-    //         // Focus on bias gradient
-    //         compare_tensors(&bias.node->grad, &expected_grad_bias, op_name, tc_name, 1, TEST_FLOAT_TOLERANCE);
-    //     }
+            // Focus on bias gradient
+            compare_tensors(&bias.node->grad, &expected_grad_bias, op_name, tc_name, 1, TEST_FLOAT_TOLERANCE);
+        }
 
-    //     // Sub-test 2: Linear followed by mean
-    //     {
-    //         TensorShape input_shape = {2, 3}; // batch_size=2, input_features=3
-    //         TensorShape weight_shape = {3, 2}; // input_features=3, output_features=2
-    //         TensorShape bias_shape = {1, 2}; // output_features=2
+        // Sub-test 2: Linear followed by mean
+        {
+            TensorShape input_shape = {2, 3}; // batch_size=2, input_features=3
+            TensorShape weight_shape = {3, 2}; // input_features=3, output_features=2
+            TensorShape bias_shape = {1, 2}; // output_features=2
 
-    //         float input_data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
-    //         float weight_data[] = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};
-    //         float bias_data[] = {0.1f, 0.2f};
+            float input_data[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+            float weight_data[] = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};
+            float bias_data[] = {0.1f, 0.2f};
 
-    //         // Expected gradients
-    //         float exp_grad_bias[] = {0.25f, 0.25f}; // For mean reduction (1/4)
+            // Expected gradients
+            float exp_grad_bias[] = {0.5f, 0.5f}; // For mean reduction (1/2)
 
-    //         Tensor input = create_test_tensor(input_shape, input_data, true);
-    //         Tensor weight = create_test_tensor(weight_shape, weight_data, true);
-    //         Tensor bias = create_test_tensor(bias_shape, bias_data, true);
+            Tensor input = create_test_tensor(input_shape, input_data, true);
+            Tensor weight = create_test_tensor(weight_shape, weight_data, true);
+            Tensor bias = create_test_tensor(bias_shape, bias_data, true);
 
-    //         Tensor output = nn_linear(input, weight, bias);
-    //         Tensor mean_output = Tensor_mean(output);
+            Tensor output = nn_linear(input, weight, bias);
+            Tensor mean_output = Tensor_mean(output);
 
-    //         Tensor_backward(mean_output, (Tensor){0});
+            Tensor_backward(mean_output, (Tensor){0});
 
-    //         Tensor expected_grad_bias = create_test_tensor(bias_shape, exp_grad_bias, false);
+            Tensor expected_grad_bias = create_test_tensor(bias_shape, exp_grad_bias, false);
 
-    //         // Focus on bias gradient
-    //         compare_tensors(&bias.node->grad, &expected_grad_bias, op_name, tc_name, 2, TEST_FLOAT_TOLERANCE);
-    //     }
-    // }
+            // Focus on bias gradient
+            compare_tensors(&bias.node->grad, &expected_grad_bias, op_name, tc_name, 2, TEST_FLOAT_TOLERANCE);
+        }
+    }
 
     cten_free(pool_id);
 }
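As an informal cross-check of the re-enabled expected values (editorial derivation, not part of the commit): the linear output has shape {2, 2}, and each bias entry accumulates gradient over the batch dimension of size 2.

    Sum:  d(sum)/d(out[i][j])  = 1,   so grad_bias[j] = 1 + 1     = 2.0
    Mean: d(mean)/d(out[i][j]) = 1/4, so grad_bias[j] = 1/4 + 1/4 = 0.5

These match the updated {2.0f, 2.0f} and {0.5f, 0.5f} expectations above.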

0 commit comments