Commit 050f9ec
Use slice-based indexing in separate_mtmvn for non-interleaved case (#3246)
Summary:
For non-interleaved MultitaskMultivariateNormal, task covariance blocks are contiguous, so we can use slice indexing (start:end) instead of tensor indexing. LinearOperator handles slices much more efficiently, avoiding excessive CatLinearOperator creation that caused the performance regression from #2920.
The interleaved case retains tensor indexing since indices are strided.
Added more rigorous tests with larger data, more tasks, and batch dims.
Partially addresses #3095
## Profiling Results: `separate_mtmvn` (non-interleaved)
Old = tensor indexing, New = slice-based indexing
### Runtime (median)
| Config | Old | New | Speedup |
|---|---|---|---|
| 10 data, 3 tasks (30x30) | 174 us | 157 us | 1.1x |
| 50 data, 4 tasks (200x200) | 455 us | 138 us | 3.3x |
| 100 data, 5 tasks (500x500) | 890 us | 207 us | 4.3x |
| 200 data, 5 tasks (1000x1000) | 1.26 ms | 152 us | 8.3x |
| 50 data, 4 tasks, batch=3 (200x200) | 721 us | 148 us | 4.9x |
| 100 data, 5 tasks, batch=3 (500x500) | 1.01 ms | 179 us | 5.6x |
### Peak Memory
| Config | Old | New | Reduction |
|---|---|---|---|
| 10 data, 3 tasks (30x30) | 7.2 KB | 5.8 KB | 1.23x |
| 50 data, 4 tasks (200x200) | 8.5 KB | 6.8 KB | 1.25x |
| 100 data, 5 tasks (500x500) | 9.8 KB | 7.9 KB | 1.25x |
| 200 data, 5 tasks (1000x1000) | 9.8 KB | 7.7 KB | 1.26x |
| 50 data, 4 tasks, batch=3 (200x200) | 8.8 KB | 6.8 KB | 1.28x |
| 100 data, 5 tasks, batch=3 (500x500) | 10.2 KB | 8.0 KB | 1.28x |
Key takeaways:
1. Runtime: The slice-based indexing is significantly faster, with speedups ranging from 1.1x at small sizes to 8.3x at 1000x1000. The new implementation stays roughly constant (~150 us) regardless of matrix size since slicing is O(1) for LinearOperators, while the old tensor-indexing approach scales with matrix size.
2. Memory: Modest ~25% reduction in peak memory across all configurations. The memory savings are small in absolute terms because the profiling measures only the indexing overhead (the covariance matrix itself dominates memory in real usage).
3. Scaling: Slice-based indexing is O(1) for LinearOperators vs O(n) for tensor indexing, so the speedup grows with matrix size.
### Profiling script
```python
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Profile memory and runtime of separate_mtmvn: old (tensor indexing) vs new
(slice-based indexing) for the non-interleaved case.
Usage:
python scripts/profile_separate_mtmvn.py
"""
from __future__ import annotations
import gc
import time
import tracemalloc
from contextlib import contextmanager
from typing import Callable
import torch
from gpytorch.distributions import MultitaskMultivariateNormal
from gpytorch.distributions.multivariate_normal import MultivariateNormal
from linear_operator import to_linear_operator
# ---------------------------------------------------------------------------
# Old implementation (tensor indexing for both branches)
# ---------------------------------------------------------------------------
def separate_mtmvn_old(
mvn: MultitaskMultivariateNormal,
) -> list[MultivariateNormal]:
full_covar = mvn.lazy_covariance_matrix
num_data, num_tasks = mvn.mean.shape[-2:]
mvns = []
for c in range(num_tasks):
if mvn._interleaved:
task_indices = torch.arange(
c, num_data * num_tasks, num_tasks, device=full_covar.device
)
else:
task_indices = torch.arange(
c * num_data, (c + 1) * num_data, device=full_covar.device
)
task_covar = full_covar[..., task_indices, :]
task_covar = task_covar[..., :, task_indices]
mvns.append(
MultivariateNormal(mvn.mean[..., c], to_linear_operator(task_covar))
)
return mvns
# ---------------------------------------------------------------------------
# New implementation (slice-based indexing for non-interleaved)
# ---------------------------------------------------------------------------
def separate_mtmvn_new(
mvn: MultitaskMultivariateNormal,
) -> list[MultivariateNormal]:
full_covar = mvn.lazy_covariance_matrix
num_data, num_tasks = mvn.mean.shape[-2:]
mvns = []
for c in range(num_tasks):
if mvn._interleaved:
task_indices = torch.arange(
c, num_data * num_tasks, num_tasks, device=full_covar.device
)
task_covar = full_covar[..., task_indices, :]
task_covar = task_covar[..., :, task_indices]
else:
start = c * num_data
end = start + num_data
task_covar = full_covar[..., start:end, start:end]
mvns.append(
MultivariateNormal(mvn.mean[..., c], to_linear_operator(task_covar))
)
return mvns
# ---------------------------------------------------------------------------
# Profiling helpers
# ---------------------------------------------------------------------------
contextmanager
def track_memory():
"""Context manager that returns peak memory usage in bytes."""
gc.collect()
tracemalloc.start()
snapshot_start = tracemalloc.take_snapshot()
result = {"peak_bytes": 0, "current_bytes": 0}
try:
yield result
finally:
_, peak = tracemalloc.get_traced_memory()
current = tracemalloc.get_traced_memory()[0]
result["peak_bytes"] = peak
result["current_bytes"] = current
tracemalloc.stop()
def benchmark_runtime(
fn: Callable,
mvn: MultitaskMultivariateNormal,
warmup: int = 3,
repeats: int = 10,
materialize: bool = False,
) -> dict[str, float]:
"""Time `fn(mvn)` and optionally materializing the covariances."""
# Warmup
for _ in range(warmup):
result = fn(mvn)
if materialize:
for m in result:
m.covariance_matrix # noqa: B018 -- force evaluation
times = []
for _ in range(repeats):
gc.collect()
t0 = time.perf_counter()
result = fn(mvn)
if materialize:
for m in result:
m.covariance_matrix # noqa: B018
t1 = time.perf_counter()
times.append(t1 - t0)
return {
"mean_s": sum(times) / len(times),
"min_s": min(times),
"max_s": max(times),
"median_s": sorted(times)[len(times) // 2],
}
def benchmark_memory(
fn: Callable,
mvn: MultitaskMultivariateNormal,
materialize: bool = False,
) -> dict[str, int]:
"""Measure peak memory of `fn(mvn)` including optional materialization."""
gc.collect()
with track_memory() as mem:
result = fn(mvn)
if materialize:
for m in result:
m.covariance_matrix # noqa: B018
return mem
def build_mtmvn(
num_data: int,
num_tasks: int,
batch_shape: torch.Size = torch.Size([]),
interleaved: bool = False,
dtype: torch.dtype = torch.float64,
) -> MultitaskMultivariateNormal:
"""Build a MultitaskMultivariateNormal for testing."""
n = num_data * num_tasks
mean = torch.rand(*batch_shape, num_data, num_tasks, dtype=dtype)
a = torch.rand(*batch_shape, n, n, dtype=dtype)
covar = a @ a.transpose(-1, -2) + torch.eye(n, dtype=dtype)
return MultitaskMultivariateNormal(
mean=mean, covariance_matrix=covar, interleaved=interleaved
)
def fmt_bytes(b: int) -> str:
if b < 1024:
return f"{b} B"
elif b < 1024**2:
return f"{b / 1024:.1f} KB"
else:
return f"{b / 1024**2:.2f} MB"
def fmt_time(s: float) -> str:
if s < 1e-3:
return f"{s * 1e6:.1f} us"
elif s < 1:
return f"{s * 1e3:.2f} ms"
else:
return f"{s:.3f} s"
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
configs = [
# (num_data, num_tasks, batch_shape)
(10, 3, torch.Size([])),
(50, 4, torch.Size([])),
(100, 5, torch.Size([])),
(200, 5, torch.Size([])),
(50, 4, torch.Size([3])),
(100, 5, torch.Size([3])),
]
print("=" * 90)
print("Profile: separate_mtmvn -- old (tensor index) vs new (slice index)")
print("Non-interleaved case only (the interleaved path is unchanged)")
print("=" * 90)
for num_data, num_tasks, batch_shape in configs:
n = num_data * num_tasks
label = (
f"num_data={num_data}, num_tasks={num_tasks}, "
f"batch={list(batch_shape)}, matrix_size={n}x{n}"
)
print(f"\n{'─' * 90}")
print(f" {label}")
print(f"{'─' * 90}")
mvn = build_mtmvn(
num_data=num_data,
num_tasks=num_tasks,
batch_shape=batch_shape,
interleaved=False,
)
# -- Runtime without materialization (lazy) --
rt_old_lazy = benchmark_runtime(separate_mtmvn_old, mvn, materialize=False)
rt_new_lazy = benchmark_runtime(separate_mtmvn_new, mvn, materialize=False)
# -- Runtime with materialization --
rt_old_mat = benchmark_runtime(separate_mtmvn_old, mvn, materialize=True)
rt_new_mat = benchmark_runtime(separate_mtmvn_new, mvn, materialize=True)
# -- Memory without materialization --
mem_old_lazy = benchmark_memory(separate_mtmvn_old, mvn, materialize=False)
mem_new_lazy = benchmark_memory(separate_mtmvn_new, mvn, materialize=False)
# -- Memory with materialization --
mem_old_mat = benchmark_memory(separate_mtmvn_old, mvn, materialize=True)
mem_new_mat = benchmark_memory(separate_mtmvn_new, mvn, materialize=True)
def speedup(old: float, new: float) -> str:
if new == 0:
return "inf"
ratio = old / new
return f"{ratio:.2f}x"
def mem_ratio(old: int, new: int) -> str:
if new == 0:
return "N/A"
ratio = old / new
return f"{ratio:.2f}x"
print(f"\n {'Metric':<35} {'Old':>12} {'New':>12} {'Speedup':>10}")
print(f" {'-' * 71}")
print(
f" {'Runtime (lazy, median)':<35} "
f"{fmt_time(rt_old_lazy['median_s']):>12} "
f"{fmt_time(rt_new_lazy['median_s']):>12} "
f"{speedup(rt_old_lazy['median_s'], rt_new_lazy['median_s']):>10}"
)
print(
f" {'Runtime (materialized, median)':<35} "
f"{fmt_time(rt_old_mat['median_s']):>12} "
f"{fmt_time(rt_new_mat['median_s']):>12} "
f"{speedup(rt_old_mat['median_s'], rt_new_mat['median_s']):>10}"
)
print(
f" {'Peak memory (lazy)':<35} "
f"{fmt_bytes(mem_old_lazy['peak_bytes']):>12} "
f"{fmt_bytes(mem_new_lazy['peak_bytes']):>12} "
f"{mem_ratio(mem_old_lazy['peak_bytes'], mem_new_lazy['peak_bytes']):>10}"
)
print(
f" {'Peak memory (materialized)':<35} "
f"{fmt_bytes(mem_old_mat['peak_bytes']):>12} "
f"{fmt_bytes(mem_new_mat['peak_bytes']):>12} "
f"{mem_ratio(mem_old_mat['peak_bytes'], mem_new_mat['peak_bytes']):>10}"
)
print(f"\n{'=' * 90}")
print("Done.")
if __name__ == "__main__":
main()
```
Pull Request resolved: #3246
Reviewed By: esantorella
Differential Revision: D97354351
Pulled By: saitcakmak
fbshipit-source-id: 47f0f6c916ac72b052c86e9d2ccac1281022cf741 parent b79646c commit 050f9ec
2 files changed
+37
-36
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
27 | 27 | | |
28 | 28 | | |
29 | 29 | | |
30 | | - | |
31 | 30 | | |
32 | 31 | | |
33 | 32 | | |
| 33 | + | |
34 | 34 | | |
35 | 35 | | |
36 | 36 | | |
| 37 | + | |
| 38 | + | |
37 | 39 | | |
38 | | - | |
39 | | - | |
40 | | - | |
41 | | - | |
42 | | - | |
43 | | - | |
44 | | - | |
45 | | - | |
46 | | - | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
47 | 46 | | |
48 | 47 | | |
49 | 48 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
13 | 13 | | |
14 | 14 | | |
15 | 15 | | |
16 | | - | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
17 | 25 | | |
18 | 26 | | |
19 | | - | |
20 | | - | |
21 | | - | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
22 | 33 | | |
23 | 34 | | |
24 | 35 | | |
25 | 36 | | |
26 | 37 | | |
27 | | - | |
28 | | - | |
29 | | - | |
30 | | - | |
31 | | - | |
32 | | - | |
33 | | - | |
34 | | - | |
35 | | - | |
36 | | - | |
37 | | - | |
38 | | - | |
39 | | - | |
40 | | - | |
41 | | - | |
42 | | - | |
43 | | - | |
44 | | - | |
45 | | - | |
46 | | - | |
47 | | - | |
48 | | - | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
0 commit comments