File tree Expand file tree Collapse file tree 1 file changed +8
-5
lines changed
deep_causality_tensor/src/types/causal_tensor/op_tensor_ein_sum Expand file tree Collapse file tree 1 file changed +8
-5
lines changed Original file line number Diff line number Diff line change @@ -352,12 +352,15 @@ where
352352 . collect ( ) ;
353353
354354 if new_shape. is_empty ( ) {
355+ // This is a 2D tensor trace, resulting in a scalar.
355356 let mut total_sum = T :: default ( ) ;
356- for i in 0 ..tensor. shape [ axis1] {
357- let mut index = vec ! [ 0 ; tensor. num_dim( ) ] ;
358- index[ axis1] = i;
359- index[ axis2] = i;
360- total_sum = total_sum + tensor. get ( & index) . unwrap ( ) . clone ( ) ;
357+ let dim = tensor. shape [ axis1] ;
358+ let stride1 = tensor. strides [ axis1] ;
359+ let stride2 = tensor. strides [ axis2] ;
360+
361+ for i in 0 ..dim {
362+ let flat_index = i * stride1 + i * stride2;
363+ total_sum = total_sum + tensor. data [ flat_index] . clone ( ) ;
361364 }
362365 return CausalTensor :: new ( vec ! [ total_sum] , vec ! [ ] ) ;
363366 }
You can’t perform that action at this time.
0 commit comments