Skip to content

Commit d2ba5e1

Browse files
authored
Memray callback (#561)
* Add memray callback to print large MALLOC and FREE calls for each operation * Use memray callback in test_mem_utilization * Add test for flip in test_mem_utilization * Skip test_mem_utilization if memray isn't installed * Fix mypy
1 parent 680b3c1 commit d2ba5e1

File tree

3 files changed

+141
-6
lines changed

3 files changed

+141
-6
lines changed

.github/workflows/slow-tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ jobs:
4040
4141
- name: Install
4242
run: |
43-
python -m pip install -e .[test]
43+
python -m pip install -e .[test] memray
4444
4545
- name: Run tests
4646
run: |

cubed/diagnostics/memray.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional

import memray
from memray._memray import compute_statistics
from memray._stats import Stats

from cubed.runtime.pipeline import visit_nodes
from cubed.runtime.types import Callback
12+
13+
14+
class AllocationType(Enum):
    """The kinds of memory events we report from a memray capture."""

    MALLOC = 1
    FREE = 2
    CALLOC = 3
    REALLOC = 4


@dataclass
class Allocation:
    """One large allocation (or the free that releases it) for an operation."""

    object_id: str
    allocation_type: AllocationType
    memory: int
    address: Optional[int] = None
    call: Optional[str] = None

    def __repr__(self) -> str:
        # Optional (or zero/falsy) fields render as blanks rather than
        # 'None'/'0', matching the original single-line format.
        parts = (
            self.object_id,
            self.allocation_type.name,
            self.memory or "",
            self.address or "",
            self.call or "",
        )
        return "{} {} {} {} {}".format(*parts)
31+
32+
33+
class MemrayCallback(Callback):
    """Process Memray results for a computation, and print large MALLOC and FREE
    calls for each operation.

    After ``on_compute_end`` runs, ``allocations`` maps each operation name to
    the list of large allocations found in its memray result file, and
    ``stats`` maps each operation name to its memray ``Stats``.
    """

    def __init__(self, mem_threshold=50_000_000) -> None:
        # Only allocations of at least this many bytes are reported.
        self.mem_threshold = mem_threshold
        self.allocations: Dict[str, List[Allocation]] = {}
        self.stats: Dict[str, Stats] = {}

    def on_compute_end(self, event):
        for name, _ in visit_nodes(event.dag):
            memray_result_file = f"history/{event.compute_id}/memray/{name}.bin"
            # Not every operation necessarily produces a memray result file.
            if not Path(memray_result_file).is_file():
                continue

            # Materialize the generator before iterating: the original code
            # stored the generator itself, which the print loop below had
            # already exhausted, so self.allocations[name] yielded nothing.
            allocations = list(
                get_allocations_over_threshold(memray_result_file, self.mem_threshold)
            )

            print(memray_result_file)
            for allocation in allocations:
                print(allocation)

            stats = compute_statistics(memray_result_file)
            print(f"Peak memory allocated: {stats.peak_memory_allocated}")
            print()

            self.allocations[name] = allocations
            self.stats[name] = stats
62+
63+
64+
def get_allocations_over_threshold(result_file, mem_threshold):
    """Yield an ``Allocation`` for every record in *result_file* of at least
    *mem_threshold* bytes, and a matching FREE ``Allocation`` (with the same
    ``object_id``) when the same address is later freed.

    :param result_file: path to a memray ``.bin`` results file
    :param mem_threshold: minimum allocation size in bytes to report
    :raises ValueError: if a large record uses an allocator other than
        MALLOC, CALLOC, or REALLOC
    """
    # Allocator kinds we report; FREE is handled separately below.
    allocator_to_type = {
        memray.AllocatorType.MALLOC: AllocationType.MALLOC,
        memray.AllocatorType.CALLOC: AllocationType.CALLOC,
        memray.AllocatorType.REALLOC: AllocationType.REALLOC,
    }
    next_id = 0  # renamed from 'id' to avoid shadowing the builtin
    address_to_allocation = {}
    with memray.FileReader(result_file) as reader:
        for record in reader.get_allocation_records():
            if record.size >= mem_threshold:
                func, mod, line = record.stack_trace()[0]
                allocation_type = allocator_to_type.get(record.allocator)
                if allocation_type is None:
                    raise ValueError(
                        f"Unsupported memray.AllocatorType {record.allocator}"
                    )
                allocation = Allocation(
                    f"object-{next_id:03}",
                    allocation_type,
                    record.size,
                    address=record.address,
                    call=f"{func};{mod};{line}",
                )
                next_id += 1
                # Remember the address so its free can be paired up later.
                address_to_allocation[record.address] = allocation
                yield allocation
            elif (
                record.allocator == memray.AllocatorType.FREE
                and record.address in address_to_allocation
            ):
                # Pair the free with the large allocation it releases;
                # frees of untracked addresses are ignored.
                allocation = address_to_allocation.pop(record.address)
                yield Allocation(
                    allocation.object_id,
                    AllocationType.FREE,
                    allocation.memory,
                    address=record.address,
                )

cubed/tests/test_mem_utilization.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,11 @@
44
import sys
55
from functools import partial, reduce
66

7+
import pandas as pd
78
import pytest
89

10+
pytest.importorskip("memray")
11+
912
import cubed
1013
import cubed.array_api as xp
1114
import cubed.random
@@ -14,9 +17,13 @@
1417
from cubed.core.optimization import multiple_inputs_optimize_dag
1518
from cubed.diagnostics.history import HistoryCallback
1619
from cubed.diagnostics.mem_warn import MemoryWarningCallback
20+
from cubed.diagnostics.memray import MemrayCallback
1721
from cubed.runtime.create import create_executor
1822
from cubed.tests.utils import LITHOPS_LOCAL_CONFIG
1923

24+
pd.set_option("display.max_columns", None)
25+
26+
2027
ALLOWED_MEM = 2_000_000_000
2128

2229
EXECUTORS = {}
@@ -107,15 +114,16 @@ def test_tril(tmp_path, spec, executor):
107114

108115

109116
@pytest.mark.slow
@pytest.mark.parametrize("optimize_graph", [False, True])
def test_add(tmp_path, spec, executor, optimize_graph):
    """Check memory utilization of an elementwise add of two random arrays,
    both with and without graph optimization."""
    a = cubed.random.random(
        (10000, 10000), chunks=(5000, 5000), spec=spec
    )  # 200MB chunks
    b = cubed.random.random(
        (10000, 10000), chunks=(5000, 5000), spec=spec
    )  # 200MB chunks
    c = xp.add(a, b)
    run_operation(tmp_path, executor, "add", c, optimize_graph=optimize_graph)
119127

120128

121129
@pytest.mark.slow
@@ -237,6 +245,16 @@ def test_concat(tmp_path, spec, executor):
237245
run_operation(tmp_path, executor, "concat", c)
238246

239247

248+
@pytest.mark.slow
def test_flip(tmp_path, spec, executor):
    """Check memory utilization of flip along axis 0."""
    # Note 'a' has one fewer element in axis=0 to force chunking to cross array boundaries
    a = cubed.random.random(
        (9999, 10000), chunks=(5000, 5000), spec=spec
    )  # 200MB chunks
    b = xp.flip(a, axis=0)
    run_operation(tmp_path, executor, "flip", b)
256+
257+
240258
@pytest.mark.slow
241259
def test_reshape(tmp_path, spec, executor):
242260
a = cubed.random.random(
@@ -305,17 +323,27 @@ def test_sum_partial_reduce(tmp_path, spec, executor):
305323
# Internal functions
306324

307325

308-
def run_operation(tmp_path, executor, name, result_array, *, optimize_function=None):
309-
# result_array.visualize(f"cubed-{name}-unoptimized", optimize_graph=False)
326+
def run_operation(
327+
tmp_path,
328+
executor,
329+
name,
330+
result_array,
331+
*,
332+
optimize_graph=True,
333+
optimize_function=None,
334+
):
335+
# result_array.visualize(f"cubed-{name}-unoptimized", optimize_graph=False, show_hidden=True)
310336
# result_array.visualize(f"cubed-{name}", optimize_function=optimize_function)
311337
hist = HistoryCallback()
312338
mem_warn = MemoryWarningCallback()
339+
memray = MemrayCallback()
313340
# use store=None to write to temporary zarr
314341
cubed.to_zarr(
315342
result_array,
316343
store=None,
317344
executor=executor,
318-
callbacks=[hist, mem_warn],
345+
callbacks=[hist, mem_warn, memray],
346+
optimize_graph=optimize_graph,
319347
optimize_function=optimize_function,
320348
)
321349

@@ -328,6 +356,13 @@ def run_operation(tmp_path, executor, name, result_array, *, optimize_function=N
328356
# check change in peak memory is no more than projected mem
329357
assert (df["peak_measured_mem_delta_mb_max"] <= df["projected_mem_mb"]).all()
330358

359+
# check memray peak memory allocated is no more than projected mem
360+
for op_name, stats in memray.stats.items():
361+
assert (
362+
stats.peak_memory_allocated
363+
<= df.query(f"name=='{op_name}'")["projected_mem_mb"].item() * 1_000_000
364+
), f"projected mem exceeds memray's peak allocated for {op_name}"
365+
331366
# check projected_mem_utilization does not exceed 1
332367
# except on processes executor that runs multiple tasks in a process
333368
if (

0 commit comments

Comments
 (0)