Skip to content

Commit 695299b

Browse files
committed
The mat_mul_2d function has been updated to use direct array access with
flat indices, improving performance and removing potential panics. Signed-off-by: Marvin Hansen <[email protected]>
1 parent 5756f9d commit 695299b

File tree

1 file changed

+11
-2
lines changed
  • deep_causality_tensor/src/types/causal_tensor/op_tensor_ein_sum

1 file changed

+11
-2
lines changed

deep_causality_tensor/src/types/causal_tensor/op_tensor_ein_sum/ein_sum_impl.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,12 +254,21 @@ where
254254

255255
let mut result_data = vec![T::default(); m * n];
256256

257+
let lhs_row_stride_i = lhs.strides[0];
258+
let lhs_col_stride = lhs.strides[1];
259+
let rhs_row_stride = rhs.strides[0];
260+
let rhs_col_stride_j = rhs.strides[1];
261+
257262
for i in 0..m {
263+
let lhs_row_base = i * lhs_row_stride_i;
258264
for j in 0..n {
265+
let rhs_col_base = j * rhs_col_stride_j;
259266
let mut sum = T::default();
260267
for l in 0..k {
261-
let lhs_val = lhs.get(&[i, l]).unwrap().clone();
262-
let rhs_val = rhs.get(&[l, j]).unwrap().clone();
268+
let lhs_idx = lhs_row_base + l * lhs_col_stride;
269+
let rhs_idx = l * rhs_row_stride + rhs_col_base;
270+
let lhs_val = lhs.data[lhs_idx].clone();
271+
let rhs_val = rhs.data[rhs_idx].clone();
263272
sum = sum + lhs_val * rhs_val;
264273
}
265274
result_data[i * n + j] = sum;

0 commit comments

Comments
 (0)