
Commit 5e07b21

Merge pull request #23 from Advaitgaur004/reduction-operator
Add: autograd support for min and max tensor operations
2 parents cd2a310 + a11826e commit 5e07b21

File tree

9 files changed: +939 −8 lines changed

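The headline change: Tensor_max and Tensor_min become arity-dispatched macros. Called with one argument they reduce over all elements to a scalar tensor; called with a dim argument they reduce along that dimension and also return the argmax/argmin positions. A hypothetical usage sketch against the API added in this diff (the demo function and however x gets built are assumptions, not part of the commit):

#include "cten.h"

/* Hypothetical driver; constructing x is outside this diff. */
void demo(Tensor x) {
    /* One argument: full reduction to a scalar tensor. */
    Tensor m = Tensor_max(x);                 /* expands to Tensor_max_all(x) */

    /* Two arguments: reduce along dim 1, keeping argmax indices. */
    TensorMaxMinResult r = Tensor_max(x, 1);  /* expands to Tensor_max_dim(x, 1) */
    /* r.values: per-row maxima; r.indices: positions along dim 1. */
    (void)m; (void)r;
}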

include/cten.h

Lines changed: 13 additions & 2 deletions

@@ -7,6 +7,10 @@
 #include <stdarg.h>
 #include <limits.h>

+#define _CTEN_PICK_REDUCE(_1, _2, NAME, ...) NAME
+#define Tensor_max(...) _CTEN_PICK_REDUCE(__VA_ARGS__, Tensor_max_dim, Tensor_max_all)(__VA_ARGS__)
+#define Tensor_min(...) _CTEN_PICK_REDUCE(__VA_ARGS__, Tensor_min_dim, Tensor_min_all)(__VA_ARGS__)
+
 #define _CTEN_PICK(_1,_2,NAME,...) NAME
 #define Tensor_mean(...) _CTEN_PICK(__VA_ARGS__, Tensor_mean_dim, Tensor_mean_all)(__VA_ARGS__)
 #define Tensor_sum(...) _CTEN_PICK(__VA_ARGS__, Tensor_sum_dim, Tensor_sum_all )(__VA_ARGS__)

@@ -33,6 +37,11 @@ typedef struct GradNode {
     const char* name;
 } GradNode;

+typedef struct {
+    Tensor values;
+    Tensor indices;
+} TensorMaxMinResult;
+
 void cten_initilize();
 void cten_finalize();

@@ -81,8 +90,10 @@ Tensor Tensor_mean_dim(Tensor self, int dim);
 Tensor Tensor_sum_all (Tensor self);
 Tensor Tensor_sum_dim (Tensor self, int dim);

-Tensor Tensor_max(Tensor self);
-Tensor Tensor_min(Tensor self);
+Tensor Tensor_max_all(Tensor self);
+TensorMaxMinResult Tensor_max_dim(Tensor self, int dim);
+Tensor Tensor_min_all(Tensor self);
+TensorMaxMinResult Tensor_min_dim(Tensor self, int dim);

 void Tensor_argmax(Tensor self, int* out);
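The _CTEN_PICK_REDUCE macro is the standard variadic arity-dispatch trick: the caller's arguments are padded with the two candidate function names, and whichever name lands in the third slot is the one that gets called. A minimal standalone sketch of the same pattern (the PICK/reduce/reduce_all/reduce_dim names are illustrative, not cten's):

#include <stdio.h>

/* Same arity dispatch as _CTEN_PICK_REDUCE: the third argument
 * position selects which function name survives. */
#define PICK(_1, _2, NAME, ...) NAME
#define reduce(...) PICK(__VA_ARGS__, reduce_dim, reduce_all)(__VA_ARGS__)

static void reduce_all(int x)          { printf("all: %d\n", x); }
static void reduce_dim(int x, int dim) { printf("dim %d of %d\n", dim, x); }

int main(void) {
    reduce(42);     /* PICK(42, reduce_dim, reduce_all) -> reduce_all(42) */
    reduce(42, 1);  /* PICK(42, 1, reduce_dim, reduce_all) -> reduce_dim(42, 1) */
    return 0;
}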

src/basic.c

Lines changed: 3 additions & 4 deletions

@@ -140,12 +140,11 @@ void Tensor_backward(Tensor self, Tensor grad) {
     }

     for(int i = 0; i < self.node->n_inputs; i++) {
-        if (self.node->inputs[i].data == NULL) {
+        Tensor input_tensor = self.node->inputs[i];
+        if (input_tensor.node == NULL) {
             continue;
         }

-        Tensor input_tensor = self.node->inputs[i];
-
         // Step 1: Get the local gradient (the partial derivative). --> For z = f(x, y), this would be dz/dx or dz/dy.
         Tensor input_grad = self.node->grad_fn(self, i);

@@ -154,7 +153,7 @@
         int input_ndim = TensorShape_dim(input_tensor.shape);
         int grad_ndim = TensorShape_dim(grad.shape);

-        if ((strcmp(self.node->name, "Sum") == 0 || strcmp(self.node->name, "Mean") == 0) && input_ndim > grad_ndim) {
+        if ((strcmp(self.node->name, "Sum") == 0 || strcmp(self.node->name, "Mean") == 0 || strcmp(self.node->name, "MaxDim") == 0 || strcmp(self.node->name, "MinDim") == 0) && input_ndim > grad_ndim) {
            // Find the dimension that was reduced. We assume the non-reduced dimensions match in size.
            int unsqueeze_dim = -1;
            int grad_idx = 0;
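MaxDim/MinDim join Sum and Mean here because a dim-reduction hands back a gradient with one fewer dimension than its input, and Tensor_backward must re-insert (unsqueeze) the reduced axis before accumulating: reducing a (2,3) tensor along dim 1 yields a grad of shape (2,), which must be viewed as (2,1) first. A small sketch of that reduced-dimension search on bare int arrays (find_reduced_dim is an illustrative helper, not a cten function):

#include <stdio.h>

/* Walk both shapes left to right; the first position where they
 * disagree (or where the output shape runs out) is the reduced dim. */
static int find_reduced_dim(const int* in_shape, int in_ndim,
                            const int* out_shape, int out_ndim) {
    int out_d = 0;
    for (int d = 0; d < in_ndim; d++) {
        if (out_d >= out_ndim || in_shape[d] != out_shape[out_d]) return d;
        out_d++;
    }
    return -1; /* nothing was reduced */
}

int main(void) {
    int in_shape[] = {2, 3};  /* input of Tensor_max_dim(x, 1) */
    int out_shape[] = {2};    /* values/indices come back as shape (2,) */
    /* prints 1: the grad of shape (2,) must be unsqueezed to (2,1) */
    printf("reduced dim = %d\n",
           find_reduced_dim(in_shape, 2, out_shape, 1));
    return 0;
}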

src/operator.c

Lines changed: 132 additions & 0 deletions

@@ -14,6 +14,12 @@
 #ifdef Tensor_sum
 #undef Tensor_sum
 #endif
+#ifdef Tensor_max
+#undef Tensor_max
+#endif
+#ifdef Tensor_min
+#undef Tensor_min
+#endif

 static Tensor GradFn_add(Tensor self, int i) {
     // f(x, y) = x + y; f'(x) = 1; f'(y) = 1

@@ -450,5 +456,131 @@ Tensor Tensor_sub(Tensor self, Tensor other) {
         res.node->n_inputs = 2;
         res.node->name = "Sub";
     }
+    return res;
+}
+
+Tensor GradFn_reduce_dim(Tensor self, int i) {
+    Tensor input = self.node->inputs[0];
+    Tensor indices_tensor = self.node->inputs[1];
+    Tensor grad_out = Tensor_zeros(input.shape, false);
+
+    int out_numel = indices_tensor.data->numel;
+    int ndim = TensorShape_dim(input.shape);
+    int reduced_dim = -1;
+
+    for(int d = 0, out_d = 0; d < ndim; d++){
+        if(out_d >= TensorShape_dim(self.shape) || input.shape[d] != self.shape[out_d]){
+            reduced_dim = d;
+            break;
+        }
+        out_d++;
+    }
+    cten_assert(reduced_dim != -1, "Could not determine reduced dimension in gradient calculation");
+
+    for (int j = 0; j < out_numel; j++) {
+        int index_along_dim = (int)indices_tensor.data->flex[j];
+
+        int linear_idx = 0, stride = 1, out_j_rem = j, out_shape_idx = TensorShape_dim(self.shape) - 1;
+        for (int k = ndim - 1; k >= 0; --k) {
+            int current_dim_idx;
+            if (k == reduced_dim) {
+                current_dim_idx = index_along_dim;
+            } else {
+                int dim_k = self.shape[out_shape_idx--];
+                current_dim_idx = out_j_rem % dim_k;
+                out_j_rem /= dim_k;
+            }
+            linear_idx += current_dim_idx * stride;
+            stride *= input.shape[k];
+        }
+        grad_out.data->flex[linear_idx] = 1.0f;
+    }
+    return grad_out;
+}
+
+Tensor GradFn_max_all(Tensor self, int i) {
+    Tensor input = self.node->inputs[i];
+    Tensor res = Tensor_zeros(input.shape, false);
+    float max_val = self.data->flex[0];
+
+    int max_count = 0;
+    for (int j = 0; j < input.data->numel; j++) {
+        if (input.data->flex[j] == max_val) max_count++;
+    }
+
+    float grad_value = (max_count > 0) ? 1.0f / max_count : 0.0f;
+    for (int j = 0; j < input.data->numel; j++) {
+        if (input.data->flex[j] == max_val) res.data->flex[j] = grad_value;
+    }
+    return res;
+}
+
+Tensor Tensor_max(Tensor self) {
+    if (self.data->numel == 0){
+        cten_assert(false, "Error: max() on an empty tensor.");
+    }
+    bool requires_grad = !cten_is_eval() && (self.node != NULL);
+    Tensor res = Tensor_new((TensorShape){1, 0, 0, 0}, requires_grad);
+
+    float max_val = self.data->flex[0];
+    for (int i = 1; i < self.data->numel; i++) {
+        if (self.data->flex[i] > max_val) {
+            max_val = self.data->flex[i];
+        }
+    }
+
+    res.data->flex[0] = max_val;
+
+    if (requires_grad) {
+        res.node->grad_fn = GradFn_max_all;
+        res.node->inputs[0] = self;
+        res.node->n_inputs = 1;
+        res.node->name = "MaxAll";
+    }
+
+    return res;
+}
+
+Tensor GradFn_min_all(Tensor self, int i) {
+    Tensor input = self.node->inputs[i];
+    Tensor res = Tensor_zeros(input.shape, false);
+    float min_val = self.data->flex[0];
+
+    int min_count = 0;
+    for (int j = 0; j < input.data->numel; j++) {
+        if (input.data->flex[j] == min_val) min_count++;
+    }
+
+    float grad_value = (min_count > 0) ? 1.0f / min_count : 0.0f;
+    for (int j = 0; j < input.data->numel; j++) {
+        if (input.data->flex[j] == min_val) res.data->flex[j] = grad_value;
+    }
+    return res;
+}
+
+Tensor Tensor_min(Tensor self) {
+    if (self.data->numel == 0){
+        cten_assert(false, "Error: min() on an empty tensor.");
+    }
+    bool requires_grad = !cten_is_eval() && (self.node != NULL);
+    Tensor res = Tensor_new((TensorShape){1, 0, 0, 0}, requires_grad);
+
+    // Find minimum value
+    float min_val = self.data->flex[0];
+    for (int i = 1; i < self.data->numel; i++) {
+        if (self.data->flex[i] < min_val) {
+            min_val = self.data->flex[i];
+        }
+    }
+
+    res.data->flex[0] = min_val;
+
+    if (requires_grad) {
+        res.node->grad_fn = GradFn_min_all;
+        res.node->inputs[0] = self;
+        res.node->n_inputs = 1;
+        res.node->name = "MinAll";
+    }
+
     return res;
 }
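GradFn_max_all and GradFn_min_all split the incoming gradient evenly across ties: every element equal to the extremum receives 1/count. A quick numeric check of that rule on a plain float array (no cten types involved):

/* Tie-splitting check: for x = {1, 3, 3}, max = 3 appears twice,
 * so the subgradient is {0, 0.5, 0.5}. */
#include <stdio.h>

int main(void) {
    float x[] = {1.0f, 3.0f, 3.0f};
    int n = 3, count = 0;
    float max_val = x[0];
    for (int i = 1; i < n; i++) if (x[i] > max_val) max_val = x[i];
    for (int i = 0; i < n; i++) if (x[i] == max_val) count++;
    for (int i = 0; i < n; i++)
        printf("grad[%d] = %g\n", i, x[i] == max_val ? 1.0f / count : 0.0f);
    return 0;
}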

src/utils.c

Lines changed: 153 additions & 2 deletions

@@ -16,6 +16,9 @@ bool va_arg_is_present(va_list args) {

 Tensor GradFn_mean(Tensor self, int i);
 Tensor GradFn_sum(Tensor self, int i);
+Tensor GradFn_max_all(Tensor self, int i);
+Tensor GradFn_min_all(Tensor self, int i);
+Tensor GradFn_reduce_dim(Tensor self, int i);

 Tensor Tensor_mean_all(Tensor self) {
     float total = 0.0f;

@@ -67,6 +70,155 @@ Tensor Tensor_sum_dim(Tensor self, int dim) {
     return res;
 }

+Tensor Tensor_max_all(Tensor self) {
+    bool requires_grad = !cten_is_eval() && (self.node != NULL);
+    Tensor res = Tensor_new((TensorShape){1, 0, 0, 0}, requires_grad);
+
+    if (self.data->numel == 0) cten_assert(false, "max on empty tensor");
+    float max_val = self.data->flex[0];
+    for (int i = 1; i < self.data->numel; i++) {
+        if (self.data->flex[i] > max_val) {
+            max_val = self.data->flex[i];
+        }
+    }
+    res.data->flex[0] = max_val;
+
+    if (requires_grad) {
+        res.node->grad_fn = GradFn_max_all;
+        res.node->inputs[0] = self;
+        res.node->n_inputs = 1;
+        res.node->name = "MaxAll";
+    }
+    return res;
+}
+
+TensorMaxMinResult Tensor_max_dim(Tensor self, int dim) {
+    int ndim = TensorShape_dim(self.shape);
+    dim = TensorShape_asdim(self.shape, dim);
+
+    TensorShape out_shape = {0};
+    int out_shape_len = 0;
+    for (int i = 0; i < ndim; i++) {
+        if (i != dim) out_shape[out_shape_len++] = self.shape[i];
+    }
+
+    bool requires_grad = !cten_is_eval() && (self.node != NULL);
+    Tensor values = Tensor_new(out_shape, requires_grad);
+    Tensor indices = Tensor_new(out_shape, false);
+
+    int dim_size = self.shape[dim];
+    for (int i = 0; i < values.data->numel; ++i) {
+        float best_val = -INFINITY;
+        int best_idx = -1;
+
+        for (int j = 0; j < dim_size; ++j) {
+            int in_linear_idx = 0, stride = 1, out_i_rem = i, out_idx_tracker = out_shape_len - 1;
+            for (int k = ndim - 1; k >= 0; --k) {
+                int current_dim_idx;
+                if (k == dim) {
+                    current_dim_idx = j;
+                } else {
+                    int dim_k = out_shape[out_idx_tracker--];
+                    current_dim_idx = out_i_rem % dim_k;
+                    out_i_rem /= dim_k;
+                }
+                in_linear_idx += current_dim_idx * stride;
+                stride *= self.shape[k];
+            }
+            float current_val = self.data->flex[in_linear_idx];
+            if (current_val > best_val) { best_val = current_val; best_idx = j; }
+        }
+        values.data->flex[i] = best_val;
+        indices.data->flex[i] = (float)best_idx;
+    }
+
+    if (requires_grad) {
+        values.node->grad_fn = GradFn_reduce_dim;
+        values.node->inputs[0] = self;
+        values.node->inputs[1] = indices;
+        values.node->n_inputs = 2;
+        values.node->name = "MaxDim";
+    }
+
+    TensorMaxMinResult result = {values, indices};
+    return result;
+}
+
+Tensor Tensor_min_all(Tensor self) {
+    bool requires_grad = !cten_is_eval() && (self.node != NULL);
+    Tensor res = Tensor_new((TensorShape){1, 0, 0, 0}, requires_grad);
+
+    if (self.data->numel == 0) cten_assert(false, "min on empty tensor");
+    float min_val = self.data->flex[0];
+    for (int i = 1; i < self.data->numel; i++) {
+        if (self.data->flex[i] < min_val) {
+            min_val = self.data->flex[i];
+        }
+    }
+    res.data->flex[0] = min_val;
+
+    if (requires_grad) {
+        res.node->grad_fn = GradFn_min_all;
+        res.node->inputs[0] = self;
+        res.node->n_inputs = 1;
+        res.node->name = "MinAll";
+    }
+    return res;
+}
+
+TensorMaxMinResult Tensor_min_dim(Tensor self, int dim) {
+    int ndim = TensorShape_dim(self.shape);
+    dim = TensorShape_asdim(self.shape, dim);
+
+    TensorShape out_shape = {0};
+    int out_shape_len = 0;
+    for (int i = 0; i < ndim; i++) {
+        if (i != dim) out_shape[out_shape_len++] = self.shape[i];
+    }
+
+    bool requires_grad = !cten_is_eval() && (self.node != NULL);
+    Tensor values = Tensor_new(out_shape, requires_grad);
+    Tensor indices = Tensor_new(out_shape, false);
+
+    int dim_size = self.shape[dim];
+    for (int i = 0; i < values.data->numel; ++i) {
+        float best_val = INFINITY;
+        int best_idx = -1;
+
+        for (int j = 0; j < dim_size; ++j) {
+            int in_linear_idx = 0, stride = 1, out_i_rem = i, out_idx_tracker = out_shape_len - 1;
+            for (int k = ndim - 1; k >= 0; --k) {
+                int current_dim_idx;
+                if (k == dim) {
+                    current_dim_idx = j;
+                } else {
+                    int dim_k = out_shape[out_idx_tracker--];
+                    current_dim_idx = out_i_rem % dim_k;
+                    out_i_rem /= dim_k;
+                }
+                in_linear_idx += current_dim_idx * stride;
+                stride *= self.shape[k];
+            }
+            float current_val = self.data->flex[in_linear_idx];
+            if (current_val < best_val) { best_val = current_val; best_idx = j; }
+        }
+        values.data->flex[i] = best_val;
+        indices.data->flex[i] = (float)best_idx;
+    }
+
+    if (requires_grad) {
+        values.node->grad_fn = GradFn_reduce_dim;
+        values.node->inputs[0] = self;
+        values.node->inputs[1] = indices;
+        values.node->n_inputs = 2;
+        values.node->name = "MinDim";
+    }
+
+    TensorMaxMinResult result = {values, indices};
+    return result;
+}
+
+
 void cten_assert(bool cond, const char* fmt, ...) {
     if(!cond) {
         va_list args;

@@ -91,7 +243,6 @@ void cten_assert_dim(const char* title, int a, int b) {
     cten_assert(a == b, "%s: %d != %d", title, a, b);
 }

-
 bool cten_elemwise_broadcast(Tensor* a, Tensor* b) {
     Tensor orig_a = *a;
     Tensor orig_b = *b;

@@ -366,4 +517,4 @@ Tensor Tensor_unsqueeze(Tensor self, int dim) {
     memcpy(res.shape, new_shape, sizeof(TensorShape));

     return res;
-}
+}
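Tensor_max_dim and Tensor_min_dim rebuild, for each output slot, the matching input linear index digit by digit from the rightmost dimension, substituting the candidate index j at the reduced axis; GradFn_reduce_dim replays the same arithmetic to scatter gradients back. A standalone sketch of that index walk for a 2x3 row-major array reduced along dim 1 (plain arrays and assumed values, not cten types):

#include <stdio.h>

int main(void) {
    int shape[2] = {2, 3};       /* input shape, row-major */
    int out_shape[1] = {2};      /* shape after dropping dim 1 */
    int dim = 1, ndim = 2, out_len = 1;
    float x[6] = {4, 9, 1, 7, 2, 8};  /* rows {4,9,1} and {7,2,8} */

    for (int i = 0; i < 2; i++) {            /* one output slot per row */
        float best = -1e30f; int best_j = -1;
        for (int j = 0; j < shape[dim]; j++) {
            /* rebuild the input linear index from the rightmost dim */
            int lin = 0, stride = 1, rem = i, t = out_len - 1;
            for (int k = ndim - 1; k >= 0; --k) {
                int idx;
                if (k == dim) idx = j;
                else { idx = rem % out_shape[t]; rem /= out_shape[t]; t--; }
                lin += idx * stride;
                stride *= shape[k];
            }
            if (x[lin] > best) { best = x[lin]; best_j = j; }
        }
        /* prints: row 0: max 9 at index 1, row 1: max 8 at index 2 */
        printf("row %d: max %g at index %d\n", i, best, best_j);
    }
    return 0;
}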
