Commit 26f7e08

Merged functions, fixed the errors in the workflow.

1 parent 67e87d8
4 files changed: 107 additions & 132 deletions

include/cten.h

Lines changed: 0 additions & 17 deletions
@@ -273,23 +273,6 @@ Tensor Tensor_divf(Tensor self, float other);
  */
 Tensor Tensor_powf(Tensor self, float other);
 
-/**
- * @brief Performs batch matrix multiplication for two 3D tensors.
- * For each batch index, multiplies the corresponding {m, n} and {n, p} matrices:
- * - self: shape {batch, m, n}
- * - other: shape {batch, n, p}
- * Returns a tensor of shape {batch, m, p} where each slice is the matrix product of the input slices.
- * Only supports strictly matched batch sizes and no broadcasting.
- * Each batch slice is extracted using Tensor_batch_slice, and standard Tensor_matmul is applied.
- * Prints the dimensions for each batch multiplication for debugging.
- * The output tensor contains all resulting batch matrix products.
- *
- * @param self Input tensor of shape {batch, m, n}
- * @param other Input tensor of shape {batch, n, p}
- * @return Output tensor of shape {batch, m, p} with the results of all batch multiplications
- */
-Tensor Tensor_matmul_batch(Tensor self, Tensor other);
-
 /**
  * @brief Matrix multiplication of two tensors
  * @param self First tensor (left operand)
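With the separate batch entry point removed from the header, batched products now go through `Tensor_matmul`, which dispatches on operand dimensionality. A minimal usage sketch, assuming only the `Tensor_new` constructor and the four-element `TensorShape` array visible elsewhere in this commit; the shapes here are illustrative, not from the test suite:

```c
/* Hypothetical sketch: batched multiply via the merged Tensor_matmul. */
TensorShape a_shape = {2, 3, 4, 0};  /* batch of two 3x4 matrices */
TensorShape b_shape = {2, 4, 5, 0};  /* batch of two 4x5 matrices */
Tensor a = Tensor_new(a_shape, false);
Tensor b = Tensor_new(b_shape, false);
Tensor c = Tensor_matmul(a, b);      /* result shape {2, 3, 5} */
```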

src/operator.c

Lines changed: 57 additions & 94 deletions
@@ -229,126 +229,89 @@ static Tensor GradFn_matmul(Tensor self, int i) {
     return Tensor_transpose(Tensor_detach(self.node->inputs[1 - i]));
 }
 
-Tensor Tensor_batch_slice(Tensor t, int batch_idx, int group_idx) {
-    int dim = TensorShape_dim(t.shape);
-
-    int m, n, offset;
-    TensorShape slice_shape = {0, 0, 0, 0};
-
-    if (dim == 3) {
-        int b = t.shape[0]; m = t.shape[1]; n = t.shape[2];
-        assert(batch_idx >= 0 && batch_idx < b);
-
-        offset = batch_idx * m * n;
-        slice_shape[0] = m; slice_shape[1] = n;
-    } else if (dim == 4) {
-        int b = t.shape[0], g = t.shape[1];
-        m = t.shape[2]; n = t.shape[3];
-
-        assert(batch_idx >= 0 && batch_idx < b);
-        assert(group_idx >= 0 && group_idx < g);
-        offset = (batch_idx * g + group_idx) * m * n;
-        slice_shape[0] = m; slice_shape[1] = n;
-    } else {
-        assert(0);
-    }
-
-    Tensor res = Tensor_new(slice_shape, t.node != NULL);
-    memcpy(res.data->flex, t.data->flex + offset, sizeof(float) * m * n);
-    return res;
-}
-
-Tensor Tensor_matmul_batch(Tensor self, Tensor other) {
+Tensor Tensor_matmul(Tensor self, Tensor other) {
     int self_dim = TensorShape_dim(self.shape);
     int other_dim = TensorShape_dim(other.shape);
 
-    assert((self_dim == 3 || self_dim == 4) && (other_dim == 3 || other_dim == 4));
+    assert(self_dim >= 2);
+    assert(other_dim >= 2);
 
-    // broadcasting
-    int batch = (self.shape[0] > other.shape[0]) ? self.shape[0] : other.shape[0];
+    int batch_self = (self_dim >= 3) ? self.shape[0] : 1;
+    int batch_other = (other_dim >= 3) ? other.shape[0] : 1;
+    int batch = (batch_self > batch_other) ? batch_self : batch_other;
 
-    int self_g = (self_dim == 4) ? self.shape[1] : 1;
-    int other_g = (other_dim == 4) ? other.shape[1] : 1;
-    int group = (self_g > other_g) ? self_g : other_g;
+    int group_self = (self_dim == 4) ? self.shape[1] : 1;
+    int group_other = (other_dim == 4) ? other.shape[1] : 1;
+    int group = (group_self > group_other) ? group_self : group_other;
 
     int m = self.shape[self_dim - 2];
     int n = self.shape[self_dim - 1];
     int p = other.shape[other_dim - 1];
-    // {b,g,m,n} * {b,g,n,p} -> {b,g,m,p} (g=1 for 3D)
-
-    assert(n == other.shape[other_dim - 2]);
 
-    TensorShape res_shape = {batch, m, p, 0};
-    if (group > 1) {
-        res_shape[0] = batch;
-        res_shape[1] = group;
-        res_shape[2] = m;
-        res_shape[3] = p;
-    }
-
-    Tensor res = Tensor_new(res_shape, self.node != NULL || other.node != NULL);
-    for(int b = 0; b < batch; b++) {
-        int selfbatch = self.shape[0] <= b ? self.shape[0] - 1 : b;
-        int otherbatch = other.shape[0] <= b ? other.shape[0] - 1 : b;
-
-        for(int g = 0; g < group; g++) {
-            int selfgroup = self_g <= g ? self_g - 1 : g;
-            int othergroup = other_g <= g ? other_g - 1 : g;
+    assert(n == other.shape[other_dim - 2]);
 
-            Tensor self_slice = Tensor_batch_slice(self, selfbatch, selfgroup);
-            Tensor other_slice = Tensor_batch_slice(other, otherbatch, othergroup);
-            Tensor res_slice = Tensor_matmul(self_slice, other_slice);
+    bool has4D = (self_dim == 4 || other_dim == 4);
 
-            int offset = ((batch > 1) ? b * group + g : g) * m * p;
-            memcpy(res.data->flex + offset, res_slice.data->flex, sizeof(float) * m * p);
+    TensorShape res_shape = {0, 0, 0, 0};
+    if (self_dim <= 2 && other_dim <= 2) {
+        res_shape[0] = m;
+        res_shape[1] = p;
+    } else {
+        res_shape[0] = batch;
+        if (has4D) {
+            res_shape[1] = group;
+            res_shape[2] = m;
+            res_shape[3] = p;
+        } else {
+            res_shape[1] = m;
+            res_shape[2] = p;
+            res_shape[3] = 0;
         }
     }
 
-    if(res.node != NULL) {
-        res.node->grad_fn = GradFn_matmul;
-        res.node->inputs[0] = self;
-        res.node->inputs[1] = other;
-        res.node->n_inputs = 2;
-        res.node->name = "MatmulBatch";
-    }
-    return res;
-}
-
-Tensor Tensor_matmul(Tensor self, Tensor other) {
-    int self_dim = TensorShape_dim(self.shape);
-    int other_dim = TensorShape_dim(other.shape);
+    Tensor res = Tensor_new(res_shape, self.node != NULL || other.node != NULL);
 
-    assert(self_dim >= 2);
-    assert(other_dim >= 2);
+    for (int b = 0; b < batch; b++) {
+        int self_b = (batch_self <= b) ? batch_self - 1 : b;
+        int other_b = (batch_other <= b) ? batch_other - 1 : b;
 
-    if (self_dim > 2 || other_dim > 2) {
-        return Tensor_matmul_batch(self, other);
-    }
+        for (int g = 0; g < group; g++) {
+            int self_g = (group_self <= g) ? group_self - 1 : g;
+            int other_g = (group_other <= g) ? group_other - 1 : g;
 
-    int m = self.shape[self_dim - 2];
-    int n = self.shape[self_dim - 1];
-    int p = other.shape[other_dim - 1];
+            int offset_self = 0;
+            if (self_dim == 4) {
+                offset_self = self_b * self.shape[1] * m * n + self_g * m * n;
+            } else if (self_dim == 3) {
+                offset_self = self_b * m * n;
+            }
 
-    assert(n == other.shape[other_dim - 2]);
+            int offset_other = 0;
+            if (other_dim == 4) {
+                offset_other = other_b * other.shape[1] * n * p + other_g * n * p;
+            } else if (other_dim == 3) {
+                offset_other = other_b * n * p;
+            }
 
-    TensorShape res_shape;
-    memcpy(res_shape, self.shape, sizeof(TensorShape));
-    res_shape[self_dim - 1] = p;
+            int offset_res = ((batch > 1) ? b * group + g : g) * m * p;
 
-    // here weight/bias have .node != NULL, so res have GradNode
-    Tensor res = Tensor_new(res_shape, self.node != NULL || other.node != NULL);
+            float* self_ptr = self.data->flex + offset_self;
+            float* other_ptr = other.data->flex + offset_other;
+            float* res_ptr = res.data->flex + offset_res;
 
-    for(int i = 0; i < m; i++) {
-        for(int j = 0; j < p; j++) {
-            float sum = 0;
-            for(int k = 0; k < n; k++) {
-                sum += self.data->flex[i * n + k] * other.data->flex[k * p + j];
+            for (int i = 0; i < m; i++) {
+                for (int j = 0; j < p; j++) {
+                    float sum = 0;
+                    for (int k = 0; k < n; k++) {
+                        sum += self_ptr[i * n + k] * other_ptr[k * p + j];
+                    }
+                    res_ptr[i * p + j] = sum;
+                }
             }
-            res.data->flex[i * p + j] = sum;
         }
     }
 
-    if(res.node != NULL) {
+    if (res.node != NULL) {
         res.node->grad_fn = GradFn_matmul;
         res.node->inputs[0] = self;
         res.node->inputs[1] = other;
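The merged function drops the slice-and-memcpy round trip in favor of direct offset arithmetic: slice (b, g) of a row-major {batch, group, m, n} tensor starts at flat index (b * group + g) * m * n, and a size-1 batch or group dimension broadcasts by clamping its index to its only slice. A standalone sketch of just that indexing rule (plain C, not cten code; the sizes are invented for illustration):

```c
#include <stdio.h>

int main(void) {
    int batch = 2, group = 3, m = 2, n = 2;
    int batch_self = 1; /* this operand has batch 1, so it broadcasts over b */
    for (int b = 0; b < batch; b++) {
        /* same clamp as the merged Tensor_matmul above */
        int self_b = (batch_self <= b) ? batch_self - 1 : b;
        for (int g = 0; g < group; g++) {
            int offset = (self_b * group + g) * m * n;
            printf("b=%d g=%d -> reads slice at flat offset %d\n", b, g, offset);
        }
    }
    return 0;
}
```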

src/utils.c

Lines changed: 29 additions & 0 deletions
@@ -222,6 +222,35 @@ TensorMaxMinResult Tensor_min_dim(Tensor self, int dim) {
     return result;
 }
 
+Tensor Tensor_batch_slice(Tensor t, int batch_idx, int group_idx) {
+    int dim = TensorShape_dim(t.shape);
+
+    int m, n, offset;
+    TensorShape slice_shape = {0, 0, 0, 0};
+
+    if (dim == 3) {
+        int b = t.shape[0]; m = t.shape[1]; n = t.shape[2];
+        assert(batch_idx >= 0 && batch_idx < b);
+
+        offset = batch_idx * m * n;
+        slice_shape[0] = m; slice_shape[1] = n;
+    } else if (dim == 4) {
+        int b = t.shape[0], g = t.shape[1];
+        m = t.shape[2]; n = t.shape[3];
+
+        assert(batch_idx >= 0 && batch_idx < b);
+        assert(group_idx >= 0 && group_idx < g);
+        offset = (batch_idx * g + group_idx) * m * n;
+        slice_shape[0] = m; slice_shape[1] = n;
+    } else {
+        assert(0);
+    }
+
+    Tensor res = Tensor_new(slice_shape, t.node != NULL);
+    memcpy(res.data->flex, t.data->flex + offset, sizeof(float) * m * n);
+    return res;
+}
+
 void cten_assert(bool cond, const char* fmt, ...) {
     if(!cond) {
         va_list args;
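Note that `Tensor_batch_slice`, now relocated to utils, copies rather than views: it allocates a fresh {m, n} tensor and memcpys m * n floats starting at the computed offset, and `group_idx` is unused for 3D inputs. A usage sketch with illustrative shapes:

```c
/* Pull the second 3x4 matrix out of a {2, 3, 4} tensor. */
TensorShape t_shape = {2, 3, 4, 0};
Tensor t = Tensor_new(t_shape, false);
Tensor slice = Tensor_batch_slice(t, 1, 0); /* a copy, not a view */
/* slice.shape is {3, 4, 0, 0}; writes to slice never touch t.   */
```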

tests/Operator/test_matmul.c

Lines changed: 21 additions & 21 deletions
@@ -282,18 +282,17 @@ void test_matmul_operator() {
             float exp_d[] = {1.0745f, 1.4433f, 0.7899f, 1.5456f, 1.4509f, 0.6064f,
                 0.9774f, 0.4197f, 1.1520f, 1.0043f, 1.7620f, 1.9396f, 1.4062f, 1.9461f, 1.9424f,
                 0.5314f, 0.8391f, 0.8748f, 0.3471f, 1.1284f, 1.1388f, 1.1492f, 1.0333f,
-                0.8970f, 1.6950f, 0.9817f, 1.0865f, 1.0302f, 0.7693f, 1.6373f};
+                0.8970f, 1.6950f, 0.9817f, 1.0865f, 1.0302f, 0.7693f, 1.6372f};
 
             Tensor t1 = create_test_tensor(s1_shape, d1, false);
             Tensor t2 = create_test_tensor(s2_shape, d2, false);
             Tensor expected_res = create_test_tensor(exp_shape, exp_d, false);
             Tensor actual_res = Tensor_matmul(t1, t2);
 
-            compare_tensors(&actual_res, &expected_res, op_name, tc_name, 1,
-                            TEST_FLOAT_TOLERANCE);
+            compare_tensors(&actual_res, &expected_res, op_name, tc_name, 1, TEST_FLOAT_TOLERANCE);
         }
 
-        // Sub-test case 1.1: Batch matrix multiplication using integers only (2x3x4 * 2x4x5)
+        // Sub-test case 2: Batch matrix multiplication using integers only (2x3x4 * 2x4x5)
         {
             TensorShape s1_shape = {2, 3, 4};
             float d1[] = {
@@ -332,10 +331,10 @@ void test_matmul_operator() {
             Tensor expected_res = create_test_tensor(exp_shape, exp_d, false);
             Tensor actual_res = Tensor_matmul(t1, t2);
 
-            compare_tensors(&actual_res, &expected_res, op_name, tc_name, 5, TEST_FLOAT_TOLERANCE);
+            compare_tensors(&actual_res, &expected_res, op_name, tc_name, 2, TEST_FLOAT_TOLERANCE);
         }
 
-        // Sub-test 2: Batch of identity matrices — result should equal second operand
+        // Sub-test 3: Batch of identity matrices — result should equal second operand
         // s1: {3,2,2} (3 identity matrices), s2: {3,2,2}
         {
             TensorShape s1_shape = {3, 2, 2};
@@ -362,10 +361,10 @@ void test_matmul_operator() {
            Tensor expected_res = create_test_tensor(exp_shape, exp_d, false);
             Tensor actual_res = Tensor_matmul(t1, t2);
 
-            compare_tensors(&actual_res, &expected_res, op_name, tc_name, 2, TEST_FLOAT_TOLERANCE);
+            compare_tensors(&actual_res, &expected_res, op_name, tc_name, 3, TEST_FLOAT_TOLERANCE);
         }
 
-        // Sub-test 3: Rectangular per-batch multiply (2 batches): {2,1,3} @ {2,3,2} -> {2,1,2}
+        // Sub-test 4: Rectangular per-batch multiply (2 batches): {2,1,3} @ {2,3,2} -> {2,1,2}
         {
            TensorShape s1_shape = {2, 1, 3};
             float d1[] = {
@@ -388,10 +387,10 @@ void test_matmul_operator() {
             Tensor expected_res = create_test_tensor(exp_shape, exp_d, false);
             Tensor actual_res = Tensor_matmul(t1, t2);
 
-            compare_tensors(&actual_res, &expected_res, op_name, tc_name, 3, TEST_FLOAT_TOLERANCE);
+            compare_tensors(&actual_res, &expected_res, op_name, tc_name, 4, TEST_FLOAT_TOLERANCE);
         }
 
-        // Sub-test 4: Batch of column-result matrices using ones to test reduction (4 batches): {4,2,3}@{4,3,1} -> {4,2,1}
+        // Sub-test 5: Batch of column-result matrices using ones to test reduction (4 batches): {4,2,3}@{4,3,1} -> {4,2,1}
         {
             TensorShape s1_shape = {4, 2, 3};
             // each 2x3 filled with ones
@@ -411,7 +410,7 @@ void test_matmul_operator() {
             Tensor expected_res = create_test_tensor(exp_shape, exp_d, false);
             Tensor actual_res = Tensor_matmul(t1, t2);
 
-            compare_tensors(&actual_res, &expected_res, op_name, tc_name, 4, TEST_FLOAT_TOLERANCE);
+            compare_tensors(&actual_res, &expected_res, op_name, tc_name, 5, TEST_FLOAT_TOLERANCE);
         }
     }
 
@@ -463,10 +462,10 @@ void test_matmul_operator() {
 
         TensorShape exp_shape = {4, 3};
         float exp_d[] = {
-            0.7616f, 1.3740f, 1.5423f, // Row 0
-            0.3637f, 1.5308f, 1.3906f, // Row 1
-            0.5558f, 1.1748f, 1.4725f, // Row 2
-            0.3675f, 0.9730f, 0.9582f, // Row 3
+            0.7617f, 1.3740f, 1.5422f, // Row 0
+            0.3638f, 1.5307f, 1.3906f, // Row 1
+            0.5559f, 1.1747f, 1.4724f, // Row 2
+            0.3675f, 0.9729f, 0.9581f, // Row 3
         };
 
         Tensor t1 = create_test_tensor(s1_shape, d1, false);
@@ -537,8 +536,8 @@ void test_matmul_operator() {
             0.4677f, 0.3816f,
             1.1133f, 0.8875f,
 
-            0.8504f, 0.6607f,
-            0.3593f, 0.1660f,
+            0.8505f, 0.6607f,
+            0.3593f, 0.1659f,
         };
 
         Tensor t1 = create_test_tensor(s1_shape, d1, false);
@@ -604,15 +603,16 @@ void test_matmul_operator() {
             4.0f, 5.0f,
             10.0f, 11.0f,
 
-            5.0f, 8.0f,
-            14.0f, 20.0f,
+            4.0f, 8.0f,
+            13.0f, 20.0f,
         };
 
         Tensor t1 = create_test_tensor(s1_shape, d1, false);
         Tensor t2 = create_test_tensor(s2_shape, d2, false);
         Tensor expected_res = create_test_tensor(exp_shape, exp_d, false);
         Tensor actual_res = Tensor_matmul(t1, t2);
 
+
         compare_tensors(&actual_res, &expected_res, op_name, tc_name, 5, TEST_FLOAT_TOLERANCE);
     }
 
@@ -643,14 +643,14 @@ void test_matmul_operator() {
             1.0f, 2.0f,
             1.0f, 1.0f,
 
-            2.0f, 2.0f,
+            3.0f, 2.0f,
             3.0f, 2.0f,
 
             1.0f, 1.0f,
             2.0f, 1.0f,
 
-            0.0f, 2.0f,
             1.0f, 3.0f,
+            3.0f, 3.0f,
         };
 
         Tensor t1 = create_test_tensor(s1_shape, d1, false);
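Several of the float edits above change only the last printed digit, which suggests the constants were re-derived at higher precision. One way to generate such expected values is a reference multiply in double precision printed to four decimals; a sketch with placeholder matrices (not the actual test data):

```c
#include <stdio.h>

/* Reference matmul in double precision: c (m x p) = a (m x n) * b (n x p). */
static void matmul_ref(const double* a, const double* b, double* c,
                       int m, int n, int p) {
    for (int i = 0; i < m; i++)
        for (int j = 0; j < p; j++) {
            double sum = 0.0;
            for (int k = 0; k < n; k++) sum += a[i * n + k] * b[k * p + j];
            c[i * p + j] = sum;
        }
}

int main(void) {
    double a[] = {0.1, 0.2, 0.3, 0.4}; /* placeholder 2x2 */
    double b[] = {0.5, 0.6, 0.7, 0.8}; /* placeholder 2x2 */
    double c[4];
    matmul_ref(a, b, c, 2, 2, 2);
    for (int i = 0; i < 4; i++) printf("%.4ff\n", c[i]);
    return 0;
}
```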
