Commit 050f9ec

saitcakmak authored and meta-codesync[bot] committed
Use slice-based indexing in separate_mtmvn for non-interleaved case (#3246)
Summary:
For a non-interleaved MultitaskMultivariateNormal, the task covariance blocks are contiguous, so we can use slice indexing (`start:end`) instead of tensor indexing. LinearOperator handles slices much more efficiently, avoiding the excessive CatLinearOperator creation that caused the performance regression from #2920. The interleaved case retains tensor indexing, since its indices are strided. Also added more rigorous tests with larger data, more tasks, and batch dims.

Partially addresses #3095

## Profiling Results: `separate_mtmvn` (non-interleaved)

Old = tensor indexing, New = slice-based indexing

### Runtime (median)

| Config | Old | New | Speedup |
|---|---|---|---|
| 10 data, 3 tasks (30x30) | 174 us | 157 us | 1.1x |
| 50 data, 4 tasks (200x200) | 455 us | 138 us | 3.3x |
| 100 data, 5 tasks (500x500) | 890 us | 207 us | 4.3x |
| 200 data, 5 tasks (1000x1000) | 1.26 ms | 152 us | 8.3x |
| 50 data, 4 tasks, batch=3 (200x200) | 721 us | 148 us | 4.9x |
| 100 data, 5 tasks, batch=3 (500x500) | 1.01 ms | 179 us | 5.6x |

### Peak Memory

| Config | Old | New | Reduction |
|---|---|---|---|
| 10 data, 3 tasks (30x30) | 7.2 KB | 5.8 KB | 1.23x |
| 50 data, 4 tasks (200x200) | 8.5 KB | 6.8 KB | 1.25x |
| 100 data, 5 tasks (500x500) | 9.8 KB | 7.9 KB | 1.25x |
| 200 data, 5 tasks (1000x1000) | 9.8 KB | 7.7 KB | 1.26x |
| 50 data, 4 tasks, batch=3 (200x200) | 8.8 KB | 6.8 KB | 1.28x |
| 100 data, 5 tasks, batch=3 (500x500) | 10.2 KB | 8.0 KB | 1.28x |

Key takeaways:

1. Runtime: Slice-based indexing is significantly faster, with speedups ranging from 1.1x at small sizes to 8.3x at 1000x1000. The new implementation stays roughly constant (~150 us) regardless of matrix size, since slicing is O(1) for LinearOperators, while the old tensor-indexing approach scales with matrix size.
2. Memory: A modest ~25% reduction in peak memory across all configurations. The savings are small in absolute terms because the profiling measures only the indexing overhead; the covariance matrix itself dominates memory in real usage.
3. Scaling: Because slice-based indexing is O(1) for LinearOperators versus O(n) for tensor indexing, the speedup grows with matrix size.

### Profiling script

```python
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Profile memory and runtime of separate_mtmvn: old (tensor indexing)
vs new (slice-based indexing) for the non-interleaved case.

Usage: python scripts/profile_separate_mtmvn.py
"""

from __future__ import annotations

import gc
import time
import tracemalloc
from contextlib import contextmanager
from typing import Callable

import torch
from gpytorch.distributions import MultitaskMultivariateNormal
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from linear_operator import to_linear_operator


# ---------------------------------------------------------------------------
# Old implementation (tensor indexing for both branches)
# ---------------------------------------------------------------------------
def separate_mtmvn_old(
    mvn: MultitaskMultivariateNormal,
) -> list[MultivariateNormal]:
    full_covar = mvn.lazy_covariance_matrix
    num_data, num_tasks = mvn.mean.shape[-2:]
    mvns = []
    for c in range(num_tasks):
        if mvn._interleaved:
            task_indices = torch.arange(
                c, num_data * num_tasks, num_tasks, device=full_covar.device
            )
        else:
            task_indices = torch.arange(
                c * num_data, (c + 1) * num_data, device=full_covar.device
            )
        task_covar = full_covar[..., task_indices, :]
        task_covar = task_covar[..., :, task_indices]
        mvns.append(
            MultivariateNormal(mvn.mean[..., c], to_linear_operator(task_covar))
        )
    return mvns


# ---------------------------------------------------------------------------
# New implementation (slice-based indexing for non-interleaved)
# ---------------------------------------------------------------------------
def separate_mtmvn_new(
    mvn: MultitaskMultivariateNormal,
) -> list[MultivariateNormal]:
    full_covar = mvn.lazy_covariance_matrix
    num_data, num_tasks = mvn.mean.shape[-2:]
    mvns = []
    for c in range(num_tasks):
        if mvn._interleaved:
            task_indices = torch.arange(
                c, num_data * num_tasks, num_tasks, device=full_covar.device
            )
            task_covar = full_covar[..., task_indices, :]
            task_covar = task_covar[..., :, task_indices]
        else:
            start = c * num_data
            end = start + num_data
            task_covar = full_covar[..., start:end, start:end]
        mvns.append(
            MultivariateNormal(mvn.mean[..., c], to_linear_operator(task_covar))
        )
    return mvns


# ---------------------------------------------------------------------------
# Profiling helpers
# ---------------------------------------------------------------------------
@contextmanager
def track_memory():
    """Context manager that reports peak memory usage in bytes."""
    gc.collect()
    tracemalloc.start()
    result = {"peak_bytes": 0, "current_bytes": 0}
    try:
        yield result
    finally:
        current, peak = tracemalloc.get_traced_memory()
        result["peak_bytes"] = peak
        result["current_bytes"] = current
        tracemalloc.stop()


def benchmark_runtime(
    fn: Callable,
    mvn: MultitaskMultivariateNormal,
    warmup: int = 3,
    repeats: int = 10,
    materialize: bool = False,
) -> dict[str, float]:
    """Time `fn(mvn)`, optionally materializing the covariances."""
    # Warmup
    for _ in range(warmup):
        result = fn(mvn)
        if materialize:
            for m in result:
                m.covariance_matrix  # noqa: B018 -- force evaluation
    times = []
    for _ in range(repeats):
        gc.collect()
        t0 = time.perf_counter()
        result = fn(mvn)
        if materialize:
            for m in result:
                m.covariance_matrix  # noqa: B018
        t1 = time.perf_counter()
        times.append(t1 - t0)
    return {
        "mean_s": sum(times) / len(times),
        "min_s": min(times),
        "max_s": max(times),
        "median_s": sorted(times)[len(times) // 2],
    }


def benchmark_memory(
    fn: Callable,
    mvn: MultitaskMultivariateNormal,
    materialize: bool = False,
) -> dict[str, int]:
    """Measure peak memory of `fn(mvn)` including optional materialization."""
    gc.collect()
    with track_memory() as mem:
        result = fn(mvn)
        if materialize:
            for m in result:
                m.covariance_matrix  # noqa: B018
    return mem


def build_mtmvn(
    num_data: int,
    num_tasks: int,
    batch_shape: torch.Size = torch.Size([]),
    interleaved: bool = False,
    dtype: torch.dtype = torch.float64,
) -> MultitaskMultivariateNormal:
    """Build a MultitaskMultivariateNormal for testing."""
    n = num_data * num_tasks
    mean = torch.rand(*batch_shape, num_data, num_tasks, dtype=dtype)
    a = torch.rand(*batch_shape, n, n, dtype=dtype)
    covar = a @ a.transpose(-1, -2) + torch.eye(n, dtype=dtype)
    return MultitaskMultivariateNormal(
        mean=mean, covariance_matrix=covar, interleaved=interleaved
    )


def fmt_bytes(b: int) -> str:
    if b < 1024:
        return f"{b} B"
    elif b < 1024**2:
        return f"{b / 1024:.1f} KB"
    else:
        return f"{b / 1024**2:.2f} MB"


def fmt_time(s: float) -> str:
    if s < 1e-3:
        return f"{s * 1e6:.1f} us"
    elif s < 1:
        return f"{s * 1e3:.2f} ms"
    else:
        return f"{s:.3f} s"


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
    configs = [
        # (num_data, num_tasks, batch_shape)
        (10, 3, torch.Size([])),
        (50, 4, torch.Size([])),
        (100, 5, torch.Size([])),
        (200, 5, torch.Size([])),
        (50, 4, torch.Size([3])),
        (100, 5, torch.Size([3])),
    ]
    print("=" * 90)
    print("Profile: separate_mtmvn -- old (tensor index) vs new (slice index)")
    print("Non-interleaved case only (the interleaved path is unchanged)")
    print("=" * 90)
    for num_data, num_tasks, batch_shape in configs:
        n = num_data * num_tasks
        label = (
            f"num_data={num_data}, num_tasks={num_tasks}, "
            f"batch={list(batch_shape)}, matrix_size={n}x{n}"
        )
        print(f"\n{'─' * 90}")
        print(f" {label}")
        print(f"{'─' * 90}")
        mvn = build_mtmvn(
            num_data=num_data,
            num_tasks=num_tasks,
            batch_shape=batch_shape,
            interleaved=False,
        )
        # -- Runtime without materialization (lazy) --
        rt_old_lazy = benchmark_runtime(separate_mtmvn_old, mvn, materialize=False)
        rt_new_lazy = benchmark_runtime(separate_mtmvn_new, mvn, materialize=False)
        # -- Runtime with materialization --
        rt_old_mat = benchmark_runtime(separate_mtmvn_old, mvn, materialize=True)
        rt_new_mat = benchmark_runtime(separate_mtmvn_new, mvn, materialize=True)
        # -- Memory without materialization --
        mem_old_lazy = benchmark_memory(separate_mtmvn_old, mvn, materialize=False)
        mem_new_lazy = benchmark_memory(separate_mtmvn_new, mvn, materialize=False)
        # -- Memory with materialization --
        mem_old_mat = benchmark_memory(separate_mtmvn_old, mvn, materialize=True)
        mem_new_mat = benchmark_memory(separate_mtmvn_new, mvn, materialize=True)

        def speedup(old: float, new: float) -> str:
            if new == 0:
                return "inf"
            ratio = old / new
            return f"{ratio:.2f}x"

        def mem_ratio(old: int, new: int) -> str:
            if new == 0:
                return "N/A"
            ratio = old / new
            return f"{ratio:.2f}x"

        print(f"\n {'Metric':<35} {'Old':>12} {'New':>12} {'Speedup':>10}")
        print(f" {'-' * 71}")
        print(
            f" {'Runtime (lazy, median)':<35} "
            f"{fmt_time(rt_old_lazy['median_s']):>12} "
            f"{fmt_time(rt_new_lazy['median_s']):>12} "
            f"{speedup(rt_old_lazy['median_s'], rt_new_lazy['median_s']):>10}"
        )
        print(
            f" {'Runtime (materialized, median)':<35} "
            f"{fmt_time(rt_old_mat['median_s']):>12} "
            f"{fmt_time(rt_new_mat['median_s']):>12} "
            f"{speedup(rt_old_mat['median_s'], rt_new_mat['median_s']):>10}"
        )
        print(
            f" {'Peak memory (lazy)':<35} "
            f"{fmt_bytes(mem_old_lazy['peak_bytes']):>12} "
            f"{fmt_bytes(mem_new_lazy['peak_bytes']):>12} "
            f"{mem_ratio(mem_old_lazy['peak_bytes'], mem_new_lazy['peak_bytes']):>10}"
        )
        print(
            f" {'Peak memory (materialized)':<35} "
            f"{fmt_bytes(mem_old_mat['peak_bytes']):>12} "
            f"{fmt_bytes(mem_new_mat['peak_bytes']):>12} "
            f"{mem_ratio(mem_old_mat['peak_bytes'], mem_new_mat['peak_bytes']):>10}"
        )
    print(f"\n{'=' * 90}")
    print("Done.")


if __name__ == "__main__":
    main()
```

Pull Request resolved: #3246

Reviewed By: esantorella

Differential Revision: D97354351

Pulled By: saitcakmak

fbshipit-source-id: 47f0f6c916ac72b052c86e9d2ccac1281022cf74
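To make the layout distinction in the summary concrete, here is a small illustrative sketch (not part of the PR; the toy sizes are arbitrary) of which covariance rows each task occupies under the two layouts:

```python
import torch

num_data, num_tasks = 3, 2
n = num_data * num_tasks  # 6 total rows in the joint covariance

# Non-interleaved: rows are grouped by task, so task c occupies the
# contiguous block c*num_data : (c+1)*num_data and a plain slice suffices.
c = 1
print(list(range(c * num_data, (c + 1) * num_data)))  # [3, 4, 5]

# Interleaved: rows alternate across tasks, so task c sits at the strided
# positions c, c+num_tasks, c+2*num_tasks, ... which a basic slice cannot
# express; tensor indexing is required.
print(torch.arange(c, n, num_tasks).tolist())  # [1, 3, 5]
```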
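And a minimal sketch of the two indexing paths the summary contrasts, assuming only that `linear_operator` is installed; the concrete operator types returned are an internal detail of LinearOperator and may vary by version:

```python
import torch
from linear_operator import to_linear_operator

op = to_linear_operator(torch.randn(100, 100))

# Slice-based indexing: the path the new non-interleaved branch takes.
sliced = op[10:20, 10:20]

# Tensor indexing: the strided-gather path the old code used for both
# branches, which can build intermediate operators internally.
idx = torch.arange(10, 20)
gathered = op[idx, :][:, idx]

print(type(sliced).__name__, tuple(sliced.shape))
print(type(gathered).__name__, tuple(gathered.shape))
```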
1 parent b79646c commit 050f9ec

File tree

2 files changed: +37 −36 lines changed


botorch/utils/multitask.py

Lines changed: 9 additions & 10 deletions
@@ -27,23 +27,22 @@ def separate_mtmvn(mvn: MultitaskMultivariateNormal) -> list[MultivariateNormal]

     mvns = []
     for c in range(num_tasks):
-        # Compute indices for task c's data points
         if mvn._interleaved:
             # For interleaved: task c data points are at positions
             # c, c+num_tasks, c+2*num_tasks, ...
+            # Must use tensor indexing for strided access.
             task_indices = torch.arange(
                 c, num_data * num_tasks, num_tasks, device=full_covar.device
             )
+            task_covar = full_covar[..., task_indices, :]
+            task_covar = task_covar[..., :, task_indices]
         else:
-            # For non-interleaved: task c data points are at positions
-            # c*num_data to (c+1)*num_data
-            task_indices = torch.arange(
-                c * num_data, (c + 1) * num_data, device=full_covar.device
-            )
-
-        # Extract covariance submatrix for task c
-        task_covar = full_covar[..., task_indices, :]
-        task_covar = task_covar[..., :, task_indices]
+            # For non-interleaved: task c data points are at contiguous positions
+            # c*num_data to (c+1)*num_data. Use slice-based indexing which
+            # LinearOperator handles more efficiently than tensor indexing.
+            start = c * num_data
+            end = start + num_data
+            task_covar = full_covar[..., start:end, start:end]

         mvns.append(
             MultivariateNormal(mvn.mean[..., c], to_linear_operator(task_covar))
test/utils/test_multitask.py

Lines changed: 28 additions & 26 deletions
@@ -13,36 +13,38 @@


 class TestSeparateMTMVN(BotorchTestCase):
-    def _test_separate_mtmvn(self, interleaved=False):
+    def test_separate_mtmvn(self) -> None:
+        for interleaved in (True, False):
+            for batch_shape in (torch.Size([]), torch.Size([3])):
+                with self.subTest(interleaved=interleaved, batch_shape=batch_shape):
+                    self._test_separate_mtmvn(
+                        interleaved=interleaved, batch_shape=batch_shape
+                    )
+
+    def _test_separate_mtmvn(self, interleaved: bool, batch_shape: torch.Size) -> None:
         for dtype in (torch.float, torch.double):
             tkwargs = {"device": self.device, "dtype": dtype}
-            mean = torch.rand(2, 2, **tkwargs)
-            a = torch.rand(4, 4, **tkwargs)
-            covar = a @ a.transpose(-1, -2) + torch.eye(4, **tkwargs)
+            num_data = 10
+            num_tasks = 4
+            n = num_data * num_tasks
+            mean = torch.rand(*batch_shape, num_data, num_tasks, **tkwargs)
+            a = torch.rand(*batch_shape, n, n, **tkwargs)
+            covar = a @ a.transpose(-1, -2) + torch.eye(n, **tkwargs)
             mvn = MultitaskMultivariateNormal(
                 mean=mean, covariance_matrix=covar, interleaved=interleaved
             )
             mtmvn_list = separate_mtmvn(mvn)

-            mean_1 = mean[..., 0]
-            mean_2 = mean[..., 1]
-            if interleaved:
-                covar_1 = covar[::2, ::2]
-                covar_2 = covar[1::2, 1::2]
-            else:
-                covar_1 = covar[:2, :2]
-                covar_2 = covar[2:, 2:]
-
-            self.assertEqual(len(mtmvn_list), 2)
-            for mvn_i, mean_i, covar_i in zip(
-                mtmvn_list, (mean_1, mean_2), (covar_1, covar_2)
-            ):
-                self.assertIsInstance(mvn_i, MultivariateNormal)
-                self.assertTrue(torch.equal(mvn_i.mean, mean_i))
-                self.assertAllClose(mvn_i.covariance_matrix, covar_i)
-
-    def test_separate_mtmvn_interleaved(self) -> None:
-        self._test_separate_mtmvn(interleaved=True)
-
-    def test_separate_mtmvn_not_interleaved(self) -> None:
-        self._test_separate_mtmvn(interleaved=False)
+            self.assertEqual(len(mtmvn_list), num_tasks)
+            dense_covar = covar.to_dense() if hasattr(covar, "to_dense") else covar
+
+            for c, mvn_c in enumerate(mtmvn_list):
+                self.assertIsInstance(mvn_c, MultivariateNormal)
+                self.assertEqual(mvn_c.mean.shape, (*batch_shape, num_data))
+                self.assertTrue(torch.equal(mvn_c.mean, mean[..., c]))
+                if interleaved:
+                    idx = torch.arange(c, n, num_tasks)
+                else:
+                    idx = torch.arange(c * num_data, (c + 1) * num_data)
+                expected_covar = dense_covar[..., idx, :][..., :, idx]
+                self.assertAllClose(mvn_c.covariance_matrix, expected_covar, atol=1e-5)