Use slice-based indexing in separate_mtmvn for non-interleaved case (#3246)#3246
Closed
saitcakmak wants to merge 1 commit intometa-pytorch:mainfrom
Closed
Use slice-based indexing in separate_mtmvn for non-interleaved case (#3246)#3246saitcakmak wants to merge 1 commit intometa-pytorch:mainfrom
saitcakmak wants to merge 1 commit intometa-pytorch:mainfrom
Conversation
|
@saitcakmak has imported this pull request. If you are a Meta employee, you can view this in D97354351. |
…eta-pytorch#3246) Summary: For non-interleaved MultitaskMultivariateNormal, task covariance blocks are contiguous, so we can use slice indexing (start:end) instead of tensor indexing. LinearOperator handles slices much more efficiently, avoiding excessive CatLinearOperator creation that caused the performance regression from meta-pytorch#2920. The interleaved case retains tensor indexing since indices are strided. Added more rigorous tests with larger data, more tasks, and batch dims. Partially addresses meta-pytorch#3095 ## Profiling Results: `separate_mtmvn` (non-interleaved) Old = tensor indexing, New = slice-based indexing ### Runtime (median) | Config | Old | New | Speedup | |---|---|---|---| | 10 data, 3 tasks (30x30) | 174 us | 157 us | 1.1x | | 50 data, 4 tasks (200x200) | 455 us | 138 us | 3.3x | | 100 data, 5 tasks (500x500) | 890 us | 207 us | 4.3x | | 200 data, 5 tasks (1000x1000) | 1.26 ms | 152 us | 8.3x | | 50 data, 4 tasks, batch=3 (200x200) | 721 us | 148 us | 4.9x | | 100 data, 5 tasks, batch=3 (500x500) | 1.01 ms | 179 us | 5.6x | ### Peak Memory | Config | Old | New | Reduction | |---|---|---|---| | 10 data, 3 tasks (30x30) | 7.2 KB | 5.8 KB | 1.23x | | 50 data, 4 tasks (200x200) | 8.5 KB | 6.8 KB | 1.25x | | 100 data, 5 tasks (500x500) | 9.8 KB | 7.9 KB | 1.25x | | 200 data, 5 tasks (1000x1000) | 9.8 KB | 7.7 KB | 1.26x | | 50 data, 4 tasks, batch=3 (200x200) | 8.8 KB | 6.8 KB | 1.28x | | 100 data, 5 tasks, batch=3 (500x500) | 10.2 KB | 8.0 KB | 1.28x | Key takeaways: 1. Runtime: The slice-based indexing is significantly faster, with speedups ranging from 1.1x at small sizes to 8.3x at 1000x1000. The new implementation stays roughly constant (~150 us) regardless of matrix size since slicing is O(1) for LinearOperators, while the old tensor-indexing approach scales with matrix size. 2. Memory: Modest ~25% reduction in peak memory across all configurations. The memory savings are small in absolute terms because the profiling measures only the indexing overhead (the covariance matrix itself dominates memory in real usage). 3. Scaling: Slice-based indexing is O(1) for LinearOperators vs O(n) for tensor indexing, so the speedup grows with matrix size. ### Profiling script ```python #!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Profile memory and runtime of separate_mtmvn: old (tensor indexing) vs new (slice-based indexing) for the non-interleaved case. Usage: python scripts/profile_separate_mtmvn.py """ from __future__ import annotations import gc import time import tracemalloc from contextlib import contextmanager from typing import Callable import torch from gpytorch.distributions import MultitaskMultivariateNormal from gpytorch.distributions.multivariate_normal import MultivariateNormal from linear_operator import to_linear_operator # --------------------------------------------------------------------------- # Old implementation (tensor indexing for both branches) # --------------------------------------------------------------------------- def separate_mtmvn_old( mvn: MultitaskMultivariateNormal, ) -> list[MultivariateNormal]: full_covar = mvn.lazy_covariance_matrix num_data, num_tasks = mvn.mean.shape[-2:] mvns = [] for c in range(num_tasks): if mvn._interleaved: task_indices = torch.arange( c, num_data * num_tasks, num_tasks, device=full_covar.device ) else: task_indices = torch.arange( c * num_data, (c + 1) * num_data, device=full_covar.device ) task_covar = full_covar[..., task_indices, :] task_covar = task_covar[..., :, task_indices] mvns.append( MultivariateNormal(mvn.mean[..., c], to_linear_operator(task_covar)) ) return mvns # --------------------------------------------------------------------------- # New implementation (slice-based indexing for non-interleaved) # --------------------------------------------------------------------------- def separate_mtmvn_new( mvn: MultitaskMultivariateNormal, ) -> list[MultivariateNormal]: full_covar = mvn.lazy_covariance_matrix num_data, num_tasks = mvn.mean.shape[-2:] mvns = [] for c in range(num_tasks): if mvn._interleaved: task_indices = torch.arange( c, num_data * num_tasks, num_tasks, device=full_covar.device ) task_covar = full_covar[..., task_indices, :] task_covar = task_covar[..., :, task_indices] else: start = c * num_data end = start + num_data task_covar = full_covar[..., start:end, start:end] mvns.append( MultivariateNormal(mvn.mean[..., c], to_linear_operator(task_covar)) ) return mvns # --------------------------------------------------------------------------- # Profiling helpers # --------------------------------------------------------------------------- contextmanager def track_memory(): """Context manager that returns peak memory usage in bytes.""" gc.collect() tracemalloc.start() snapshot_start = tracemalloc.take_snapshot() result = {"peak_bytes": 0, "current_bytes": 0} try: yield result finally: _, peak = tracemalloc.get_traced_memory() current = tracemalloc.get_traced_memory()[0] result["peak_bytes"] = peak result["current_bytes"] = current tracemalloc.stop() def benchmark_runtime( fn: Callable, mvn: MultitaskMultivariateNormal, warmup: int = 3, repeats: int = 10, materialize: bool = False, ) -> dict[str, float]: """Time `fn(mvn)` and optionally materializing the covariances.""" # Warmup for _ in range(warmup): result = fn(mvn) if materialize: for m in result: m.covariance_matrix # noqa: B018 -- force evaluation times = [] for _ in range(repeats): gc.collect() t0 = time.perf_counter() result = fn(mvn) if materialize: for m in result: m.covariance_matrix # noqa: B018 t1 = time.perf_counter() times.append(t1 - t0) return { "mean_s": sum(times) / len(times), "min_s": min(times), "max_s": max(times), "median_s": sorted(times)[len(times) // 2], } def benchmark_memory( fn: Callable, mvn: MultitaskMultivariateNormal, materialize: bool = False, ) -> dict[str, int]: """Measure peak memory of `fn(mvn)` including optional materialization.""" gc.collect() with track_memory() as mem: result = fn(mvn) if materialize: for m in result: m.covariance_matrix # noqa: B018 return mem def build_mtmvn( num_data: int, num_tasks: int, batch_shape: torch.Size = torch.Size([]), interleaved: bool = False, dtype: torch.dtype = torch.float64, ) -> MultitaskMultivariateNormal: """Build a MultitaskMultivariateNormal for testing.""" n = num_data * num_tasks mean = torch.rand(*batch_shape, num_data, num_tasks, dtype=dtype) a = torch.rand(*batch_shape, n, n, dtype=dtype) covar = a @ a.transpose(-1, -2) + torch.eye(n, dtype=dtype) return MultitaskMultivariateNormal( mean=mean, covariance_matrix=covar, interleaved=interleaved ) def fmt_bytes(b: int) -> str: if b < 1024: return f"{b} B" elif b < 1024**2: return f"{b / 1024:.1f} KB" else: return f"{b / 1024**2:.2f} MB" def fmt_time(s: float) -> str: if s < 1e-3: return f"{s * 1e6:.1f} us" elif s < 1: return f"{s * 1e3:.2f} ms" else: return f"{s:.3f} s" # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main() -> None: configs = [ # (num_data, num_tasks, batch_shape) (10, 3, torch.Size([])), (50, 4, torch.Size([])), (100, 5, torch.Size([])), (200, 5, torch.Size([])), (50, 4, torch.Size([3])), (100, 5, torch.Size([3])), ] print("=" * 90) print("Profile: separate_mtmvn -- old (tensor index) vs new (slice index)") print("Non-interleaved case only (the interleaved path is unchanged)") print("=" * 90) for num_data, num_tasks, batch_shape in configs: n = num_data * num_tasks label = ( f"num_data={num_data}, num_tasks={num_tasks}, " f"batch={list(batch_shape)}, matrix_size={n}x{n}" ) print(f"\n{'─' * 90}") print(f" {label}") print(f"{'─' * 90}") mvn = build_mtmvn( num_data=num_data, num_tasks=num_tasks, batch_shape=batch_shape, interleaved=False, ) # -- Runtime without materialization (lazy) -- rt_old_lazy = benchmark_runtime(separate_mtmvn_old, mvn, materialize=False) rt_new_lazy = benchmark_runtime(separate_mtmvn_new, mvn, materialize=False) # -- Runtime with materialization -- rt_old_mat = benchmark_runtime(separate_mtmvn_old, mvn, materialize=True) rt_new_mat = benchmark_runtime(separate_mtmvn_new, mvn, materialize=True) # -- Memory without materialization -- mem_old_lazy = benchmark_memory(separate_mtmvn_old, mvn, materialize=False) mem_new_lazy = benchmark_memory(separate_mtmvn_new, mvn, materialize=False) # -- Memory with materialization -- mem_old_mat = benchmark_memory(separate_mtmvn_old, mvn, materialize=True) mem_new_mat = benchmark_memory(separate_mtmvn_new, mvn, materialize=True) def speedup(old: float, new: float) -> str: if new == 0: return "inf" ratio = old / new return f"{ratio:.2f}x" def mem_ratio(old: int, new: int) -> str: if new == 0: return "N/A" ratio = old / new return f"{ratio:.2f}x" print(f"\n {'Metric':<35} {'Old':>12} {'New':>12} {'Speedup':>10}") print(f" {'-' * 71}") print( f" {'Runtime (lazy, median)':<35} " f"{fmt_time(rt_old_lazy['median_s']):>12} " f"{fmt_time(rt_new_lazy['median_s']):>12} " f"{speedup(rt_old_lazy['median_s'], rt_new_lazy['median_s']):>10}" ) print( f" {'Runtime (materialized, median)':<35} " f"{fmt_time(rt_old_mat['median_s']):>12} " f"{fmt_time(rt_new_mat['median_s']):>12} " f"{speedup(rt_old_mat['median_s'], rt_new_mat['median_s']):>10}" ) print( f" {'Peak memory (lazy)':<35} " f"{fmt_bytes(mem_old_lazy['peak_bytes']):>12} " f"{fmt_bytes(mem_new_lazy['peak_bytes']):>12} " f"{mem_ratio(mem_old_lazy['peak_bytes'], mem_new_lazy['peak_bytes']):>10}" ) print( f" {'Peak memory (materialized)':<35} " f"{fmt_bytes(mem_old_mat['peak_bytes']):>12} " f"{fmt_bytes(mem_new_mat['peak_bytes']):>12} " f"{mem_ratio(mem_old_mat['peak_bytes'], mem_new_mat['peak_bytes']):>10}" ) print(f"\n{'=' * 90}") print("Done.") if __name__ == "__main__": main() ``` Reviewed By: esantorella Differential Revision: D97354351 Pulled By: saitcakmak
936d72b to
010e0d6
Compare
|
@saitcakmak has exported this pull request. If you are a Meta employee, you can view the originating Diff in D97354351. |
|
@saitcakmak merged this pull request in 050f9ec. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary:
For non-interleaved MultitaskMultivariateNormal, task covariance blocks are contiguous, so we can use slice indexing (start:end) instead of tensor indexing. LinearOperator handles slices much more efficiently, avoiding excessive CatLinearOperator creation that caused the performance regression from #2920.
The interleaved case retains tensor indexing since indices are strided.
Added more rigorous tests with larger data, more tasks, and batch dims.
Partially addresses #3095
Profiling Results:
separate_mtmvn(non-interleaved)Old = tensor indexing, New = slice-based indexing
Runtime (median)
Peak Memory
Key takeaways:
Profiling script
Reviewed By: esantorella
Differential Revision: D97354351
Pulled By: saitcakmak