feat: chunked/streaming cross-covariance accumulation for large N #29
Open
Labels
enhancement (New feature or request) · moderate (Moderate impact, fix when possible) · performance (Runtime performance improvement)
Description
Goal
Add an optional `chunk_size` parameter to `kabsch` (and `kabsch_umeyama`) enabling incremental accumulation of the cross-covariance matrix H for very large N, without requiring the full `[B, N, D]` tensor in memory at once.
Motivation
The cross-covariance H = P^T Q is computed in a single batched matmul over all N points. For large N (e.g. dense LiDAR scans, 10^5 -- 10^6 points per cloud), the intermediate centered tensors `p` and `q` of shape `[B, N, D]` may exceed GPU memory, causing OOM errors. There is currently no workaround.
H can be accumulated in chunks:
H = sum over chunks of: p_chunk^T @ q_chunk
This is mathematically equivalent and reduces peak memory from O(B·N·D) to O(B·chunk_size·D).
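A minimal NumPy sketch of the accumulation above (the helper name is hypothetical, not part of the library; `P` and `Q` are assumed already centered):

```python
import numpy as np

def chunked_cross_covariance(P, Q, chunk_size):
    """Accumulate H = sum over chunks of p_chunk^T @ q_chunk.

    P, Q: [B, N, D] arrays, assumed already centered.
    (Hypothetical helper, for illustration only.)
    """
    B, N, D = P.shape
    H = np.zeros((B, D, D), dtype=P.dtype)
    for start in range(0, N, chunk_size):
        p = P[:, start:start + chunk_size]  # [B, <=chunk_size, D] view
        q = Q[:, start:start + chunk_size]
        # batched p_chunk^T @ q_chunk for this chunk, summed into H
        H += np.einsum('bnd,bne->bde', p, q)
    return H
```

Each iteration touches only one `[B, chunk_size, D]` slice, so any per-chunk intermediates stay bounded by `chunk_size` rather than N.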
Proposed API
```python
kabsch(P, Q, chunk_size=None)  # None = current behavior (full N at once)
kabsch(P, Q, chunk_size=1024)  # accumulate H in blocks of 1024 points
```

The centroid must be computed first (or provided externally -- a separate design question).
Design Considerations
- Centroid computation still requires a pass over all N; this can be done in the same chunked loop
- The chunk loop is a Python-level loop over N, which breaks JAX JIT tracing for dynamic chunk sizes. Options:
  - JAX: require a static `chunk_size` (known at trace time), or use `jax.lax.scan`
  - PyTorch/TF/NumPy: a Python loop is fine
- This feature is most useful for NumPy and PyTorch; a JAX `lax.scan` version would be a bonus
- Horn does not benefit (same `H` construction, same issue)
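The JAX static-`chunk_size` option above could look like the following sketch (hypothetical helper name; assumes N is divisible by `chunk_size`, which is known at trace time):

```python
import jax
import jax.numpy as jnp

def chunked_H_scan(P, Q, chunk_size):
    """Accumulate H = P^T @ Q with jax.lax.scan over fixed-size chunks.

    P, Q: [B, N, D], assumed centered; N must be divisible by chunk_size.
    (Hypothetical helper, for illustration only.)
    """
    B, N, D = P.shape
    n_chunks = N // chunk_size
    # Reshape to [n_chunks, B, chunk_size, D] so scan walks the chunk axis
    Pc = P.reshape(B, n_chunks, chunk_size, D).transpose(1, 0, 2, 3)
    Qc = Q.reshape(B, n_chunks, chunk_size, D).transpose(1, 0, 2, 3)

    def step(H, pq):
        p, q = pq
        # add this chunk's batched p^T @ q to the running carry
        return H + jnp.einsum('bnd,bne->bde', p, q), None

    H0 = jnp.zeros((B, D, D), dtype=P.dtype)
    H, _ = jax.lax.scan(step, H0, (Pc, Qc))
    return H
```

Because `chunk_size` is a Python int at trace time, the whole loop traces to a single `scan` primitive and stays JIT-compatible.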
Acceptance Criteria
- `chunk_size` parameter added to `kabsch` and `kabsch_umeyama` in at least the NumPy and PyTorch backends
- Correctness tests: chunked result matches full-N result to float32 tolerance
- Memory test: verify that peak memory for a large-N input is bounded by `chunk_size`, not N
- Documentation of the parameter and its tradeoffs