Commit 8bde169

feat: Optimize N-dimensional tensor trace calculation
Optimized the trace function in ein_sum_impl.rs to iterate directly over the contributing diagonal elements and their batches. This change avoids a full tensor scan and inefficient index reconstruction, leading to a more efficient trace calculation for N-dimensional tensors.

- Refactored the internal iteration logic within the trace function to leverage direct index calculation using tensor strides.
- Eliminated the redundant full tensor iteration and the conditional checks for diagonal elements.
- Ensured correctness with existing test cases.
- Implemented a helper function for test utilities.
- Fixed a warning.
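For illustration, here is a minimal standalone sketch of the same stride-based technique, written against a plain slice rather than the crate's `CausalTensor` type. The `trace_nd` name and signature are hypothetical, invented for this example; only the stride arithmetic mirrors the committed code.

```rust
/// Illustrative sketch: trace of an N-dimensional row-major tensor over
/// `axis1`/`axis2`, visiting only the contributing diagonal elements.
fn trace_nd(
    data: &[f64],
    shape: &[usize],
    strides: &[usize],
    axis1: usize,
    axis2: usize,
) -> Vec<f64> {
    assert_eq!(shape[axis1], shape[axis2], "traced axes must match in length");
    let diag_len = shape[axis1];

    // Axes that survive the trace ("batch" axes).
    let batch_axes: Vec<usize> = (0..shape.len())
        .filter(|&i| i != axis1 && i != axis2)
        .collect();
    let num_batch: usize = batch_axes.iter().map(|&ax| shape[ax]).product();

    // Both traced indices advance together, so each diagonal step adds a
    // fixed amount to the flat index.
    let diag_step = strides[axis1] + strides[axis2];

    let mut result = Vec::with_capacity(num_batch);
    let mut idx = vec![0usize; batch_axes.len()];
    for _ in 0..num_batch {
        // Flat offset of the current batch slice.
        let offset: usize = batch_axes
            .iter()
            .zip(&idx)
            .map(|(&ax, &i)| i * strides[ax])
            .sum();
        result.push((0..diag_len).map(|i| data[offset + i * diag_step]).sum());

        // Odometer-style increment of the batch indices.
        for k in (0..idx.len()).rev() {
            idx[k] += 1;
            if idx[k] < shape[batch_axes[k]] {
                break;
            }
            idx[k] = 0;
        }
    }
    result
}
```

The key point is that no per-element index vector has to be rebuilt: each diagonal step advances the flat index by `strides[axis1] + strides[axis2]`, and only the batch indices are incremented between slices.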
1 parent 038de90 commit 8bde169

File tree

1 file changed: +32 additions, −17 deletions
  • deep_causality_tensor/src/types/causal_tensor/op_tensor_ein_sum


deep_causality_tensor/src/types/causal_tensor/op_tensor_ein_sum/ein_sum_impl.rs

Lines changed: 32 additions & 17 deletions
@@ -363,28 +363,43 @@ where
         }
 
         let mut result_tensor = CausalTensor::full(&new_shape, T::default());
-        let mut current_index = vec![0; tensor.num_dim()];
-
-        for i in 0..tensor.len() {
-            if current_index[axis1] == current_index[axis2] {
-                let result_index: Vec<usize> = current_index
-                    .iter()
-                    .enumerate()
-                    .filter(|&(i, _)| i != axis1 && i != axis2)
-                    .map(|(_, &val)| val)
-                    .collect();
-
-                if let Some(res_val) = result_tensor.get_mut(&result_index) {
-                    *res_val = res_val.clone() + tensor.data[i].clone();
+        let diag_len = tensor.shape[axis1];
+
+        let mut batch_axes = Vec::new();
+        for i in 0..tensor.num_dim() {
+            if i != axis1 && i != axis2 {
+                batch_axes.push(i);
+            }
+        }
+
+        let num_batch_elements: usize = batch_axes.iter().map(|&ax| tensor.shape[ax]).product();
+        let mut current_batch_indices = vec![0; batch_axes.len()];
+
+        for _ in 0..num_batch_elements {
+            let result_index = current_batch_indices.clone();
+            if let Some(res_val) = result_tensor.get_mut(&result_index) {
+                let mut batch_offset = 0;
+                for (k, &batch_axis) in batch_axes.iter().enumerate() {
+                    batch_offset += current_batch_indices[k] * tensor.strides[batch_axis];
+                }
+
+                let mut diag_sum = T::default();
+                for i in 0..diag_len {
+                    let flat_index = batch_offset + i * tensor.strides[axis1] + i * tensor.strides[axis2];
+                    diag_sum = diag_sum + tensor.data[flat_index].clone();
                 }
+                *res_val = diag_sum;
             }
 
-            for j in (0..tensor.num_dim()).rev() {
-                current_index[j] += 1;
-                if current_index[j] < tensor.shape[j] {
+            // Increment batch indices
+            let mut k = batch_axes.len();
+            while k > 0 {
+                k -= 1;
+                current_batch_indices[k] += 1;
+                if current_batch_indices[k] < tensor.shape[batch_axes[k]] {
                     break;
                 }
-                current_index[j] = 0;
+                current_batch_indices[k] = 0;
             }
         }
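As a sanity check on the index arithmetic, using the hypothetical `trace_nd` sketch above: for a row-major tensor of shape [2, 3, 3] traced over axes 1 and 2, the strides are [9, 3, 1], each diagonal step advances the flat index by 3 + 1 = 4, and batch b starts at offset 9b. The new loop therefore touches 2 × 3 = 6 elements instead of scanning all 18.

```rust
fn main() {
    // Two stacked 3x3 matrices, row-major: strides are [9, 3, 1].
    let shape = [2, 3, 3];
    let strides = [9, 3, 1];
    let data: Vec<f64> = (0..18).map(f64::from).collect();

    // Batch 0 diagonal sits at flat indices 0, 4, 8   -> 0 + 4 + 8   = 12
    // Batch 1 diagonal sits at flat indices 9, 13, 17 -> 9 + 13 + 17 = 39
    let traces = trace_nd(&data, &shape, &strides, 1, 2);
    assert_eq!(traces, vec![12.0, 39.0]);
}
```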