Skip to content

Commit 6a5e3a8

Browse files
committed
perf: Optimize 2D tensor trace calculation
1 parent 8bde169 commit 6a5e3a8

File tree

1 file changed

+8
-5
lines changed
  • deep_causality_tensor/src/types/causal_tensor/op_tensor_ein_sum

1 file changed

+8
-5
lines changed

deep_causality_tensor/src/types/causal_tensor/op_tensor_ein_sum/ein_sum_impl.rs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -352,12 +352,15 @@ where
352352
.collect();
353353

354354
if new_shape.is_empty() {
355+
// This is a 2D tensor trace, resulting in a scalar.
355356
let mut total_sum = T::default();
356-
for i in 0..tensor.shape[axis1] {
357-
let mut index = vec![0; tensor.num_dim()];
358-
index[axis1] = i;
359-
index[axis2] = i;
360-
total_sum = total_sum + tensor.get(&index).unwrap().clone();
357+
let dim = tensor.shape[axis1];
358+
let stride1 = tensor.strides[axis1];
359+
let stride2 = tensor.strides[axis2];
360+
361+
for i in 0..dim {
362+
let flat_index = i * stride1 + i * stride2;
363+
total_sum = total_sum + tensor.data[flat_index].clone();
361364
}
362365
return CausalTensor::new(vec![total_sum], vec![]);
363366
}

0 commit comments

Comments
 (0)