feat: chunked/streaming cross-covariance accumulation for large N #29

@hunter-heidenreich

Description

Goal

Add an optional chunk_size parameter to kabsch (and kabsch_umeyama) enabling incremental accumulation of the cross-covariance matrix H for very large N without requiring the full [B, N, D] tensor in memory at once.

Motivation

The cross-covariance H = P^T Q is currently computed in a single batched matmul over all N points. For large N (e.g. dense LiDAR scans with 10^5 to 10^6 points per cloud), the intermediate tensors p and q, each of shape [B, N, D], may exceed GPU memory and cause OOM errors. There is currently no workaround.

H can be accumulated in chunks:

H = sum over chunks of: p_chunk^T @ q_chunk

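This is mathematically equivalent (floating-point results can differ slightly because the summation order changes) and reduces the peak memory of the intermediates from O(B·N·D) to O(B·chunk_size·D), plus O(B·D^2) for the accumulated H.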
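
The accumulation above can be sketched in NumPy as follows. This is a minimal two-pass illustration, not the library's implementation: the function name `chunked_cross_covariance` is hypothetical, and it assumes P and Q are already available as [B, N, D] arrays so that only the centered intermediates are chunked.

```python
import numpy as np


def chunked_cross_covariance(P, Q, chunk_size=None):
    """Accumulate H = sum over chunks of p_chunk^T @ q_chunk.

    Hypothetical helper for illustration only.
    P, Q: arrays of shape [B, N, D]. Returns H of shape [B, D, D].
    """
    P = np.asarray(P)
    Q = np.asarray(Q)
    B, N, D = P.shape
    if chunk_size is None:
        chunk_size = N  # current behavior: all N points in one block
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    # Pass 1: centroids. A mean is a reduction, so no [B, N, D] copy is made.
    mu_p = P.mean(axis=1, keepdims=True)  # [B, 1, D]
    mu_q = Q.mean(axis=1, keepdims=True)  # [B, 1, D]

    # Pass 2: center one chunk at a time and add its contribution to H.
    dtype = np.result_type(P.dtype, Q.dtype, np.float32)
    H = np.zeros((B, D, D), dtype=dtype)
    for start in range(0, N, chunk_size):
        p = P[:, start:start + chunk_size] - mu_p  # [B, <=chunk_size, D]
        q = Q[:, start:start + chunk_size] - mu_q
        H += np.swapaxes(p, -1, -2) @ q            # [B, D, D]
    return H
```

The last chunk may be shorter than `chunk_size`; slicing handles that without padding. Only the two centered chunks and the [B, D, D] accumulator are allocated inside the loop.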

Proposed API

kabsch(P, Q, chunk_size=None)  # None = current behavior (full N at once)
kabsch(P, Q, chunk_size=1024)  # accumulate H in blocks of 1024 points

The centroid must be computed first (or provided externally, which is a separate design question).

Design Considerations

  • Centroid computation still requires a pass over all N; this can be done in the same chunked loop
  • The chunk loop is a Python-level loop over N, which breaks JAX JIT tracing for dynamic chunk sizes. Options:
    • JAX: require static chunk_size (known at trace time), or use jax.lax.scan
    • PyTorch/TF/NumPy: Python loop is fine
  • This feature is most useful for NumPy and PyTorch; JAX's lax.scan version would be a bonus
  • Horn does not benefit (same H construction, same issue)
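
One way to fold the centroid computation into the same chunked loop is the identity sum_i (p_i - mu_p)(q_i - mu_q)^T = sum_i p_i q_i^T - N · mu_p mu_q^T. The sketch below is an illustration under assumptions, not a proposed implementation: the function name is hypothetical, and the float64 accumulators are my own choice to limit the cancellation that this identity can suffer when the clouds sit far from the origin.

```python
import numpy as np


def single_pass_cross_covariance(P, Q, chunk_size):
    """Accumulate centroids and H in one chunked pass over N.

    Hypothetical helper for illustration only.
    P, Q: arrays of shape [B, N, D].
    Returns (H [B, D, D], mu_p [B, D], mu_q [B, D]).
    """
    B, N, D = P.shape
    # Assumption: float64 accumulators, to reduce cancellation error in the
    # final subtraction when |mu| is large compared with the point spread.
    sum_p = np.zeros((B, D), dtype=np.float64)
    sum_q = np.zeros((B, D), dtype=np.float64)
    S = np.zeros((B, D, D), dtype=np.float64)  # raw sum of p_i q_i^T

    for start in range(0, N, chunk_size):
        p = P[:, start:start + chunk_size].astype(np.float64)
        q = Q[:, start:start + chunk_size].astype(np.float64)
        sum_p += p.sum(axis=1)
        sum_q += q.sum(axis=1)
        S += np.swapaxes(p, -1, -2) @ q

    mu_p = sum_p / N
    mu_q = sum_q / N
    # Subtract N * outer(mu_p, mu_q) to recover the centered cross-covariance.
    H = S - N * mu_p[:, :, None] * mu_q[:, None, :]
    return H, mu_p, mu_q
```

The two-pass form (centroids first, then centered chunks) is numerically safer; the single-pass form reads each chunk only once, which matters when the chunks are streamed from disk or another device.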

Acceptance Criteria

  • chunk_size parameter added to kabsch and kabsch_umeyama in at least NumPy and PyTorch backends
  • Correctness tests: chunked result matches full-N result to float32 tolerance
  • Memory test: verify that peak memory for a large-N input scales with chunk_size rather than N
  • Documentation of the parameter and its tradeoffs
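
The correctness criterion could be checked along these lines. This is a sketch under assumptions: `cross_covariance` is a hypothetical stand-in for the H construction inside `kabsch`, and the relative tolerance of 1e-4 is an illustrative choice for float32, not a value taken from the project.

```python
import numpy as np


def cross_covariance(P, Q, chunk_size=None):
    """Hypothetical stand-in for the H construction; None means full N."""
    B, N, D = P.shape
    step = N if chunk_size is None else chunk_size
    mu_p = P.mean(axis=1, keepdims=True)
    mu_q = Q.mean(axis=1, keepdims=True)
    H = np.zeros((B, D, D), dtype=P.dtype)
    for s in range(0, N, step):
        p = P[:, s:s + step] - mu_p
        q = Q[:, s:s + step] - mu_q
        H += np.swapaxes(p, -1, -2) @ q
    return H


def check_chunked_matches_full(n_points=2000,
                               chunk_sizes=(1, 7, 256, 2000, 5000),
                               rel_tol=1e-4):
    """Compare chunked H against full-N H in float32.

    Covers chunk sizes that do not divide N, equal N, and exceed N.
    Returns the worst error relative to the largest entry of the full H.
    """
    rng = np.random.default_rng(0)
    P = rng.standard_normal((2, n_points, 3)).astype(np.float32)
    Q = P + 0.01 * rng.standard_normal(P.shape).astype(np.float32)

    H_full = cross_covariance(P, Q)
    scale = float(np.abs(H_full).max())
    worst = 0.0
    for c in chunk_sizes:
        H_chunk = cross_covariance(P, Q, chunk_size=c)
        assert H_chunk.dtype == np.float32
        err = float(np.abs(H_chunk - H_full).max())
        assert err <= rel_tol * scale, f"chunk_size={c}: error {err}"
        worst = max(worst, err)
    return worst / scale
```

The error is measured relative to the largest entry of H rather than elementwise, because off-diagonal entries can be close to zero and would make a purely relative elementwise comparison fail on rounding noise alone.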

Metadata

Assignees

No one assigned

    Labels

    enhancement (New feature or request), moderate (Moderate impact, fix when possible), performance (Runtime performance improvement)
