 import sys
 from functools import partial, reduce
 
+import pandas as pd
 import pytest
 
+pytest.importorskip("memray")
+
 import cubed
 import cubed.array_api as xp
 import cubed.random
 from cubed.core.optimization import multiple_inputs_optimize_dag
 from cubed.diagnostics.history import HistoryCallback
 from cubed.diagnostics.mem_warn import MemoryWarningCallback
+from cubed.diagnostics.memray import MemrayCallback
 from cubed.runtime.create import create_executor
 from cubed.tests.utils import LITHOPS_LOCAL_CONFIG
 
+pd.set_option("display.max_columns", None)
+
+
 ALLOWED_MEM = 2_000_000_000
 
 EXECUTORS = {}
@@ -107,15 +114,16 @@ def test_tril(tmp_path, spec, executor):
 
 
 @pytest.mark.slow
-def test_add(tmp_path, spec, executor):
+@pytest.mark.parametrize("optimize_graph", [False, True])
+def test_add(tmp_path, spec, executor, optimize_graph):
     a = cubed.random.random(
         (10000, 10000), chunks=(5000, 5000), spec=spec
     )  # 200MB chunks
     b = cubed.random.random(
         (10000, 10000), chunks=(5000, 5000), spec=spec
     )  # 200MB chunks
     c = xp.add(a, b)
-    run_operation(tmp_path, executor, "add", c)
+    run_operation(tmp_path, executor, "add", c, optimize_graph=optimize_graph)
 
 
 @pytest.mark.slow
@@ -237,6 +245,16 @@ def test_concat(tmp_path, spec, executor):
     run_operation(tmp_path, executor, "concat", c)
 
 
+@pytest.mark.slow
+def test_flip(tmp_path, spec, executor):
+    # Note 'a' has one fewer element in axis=0 to force chunking to cross array boundaries
+    a = cubed.random.random(
+        (9999, 10000), chunks=(5000, 5000), spec=spec
+    )  # 200MB chunks
+    b = xp.flip(a, axis=0)
+    run_operation(tmp_path, executor, "flip", b)
+
+
 @pytest.mark.slow
 def test_reshape(tmp_path, spec, executor):
     a = cubed.random.random(
@@ -305,17 +323,27 @@ def test_sum_partial_reduce(tmp_path, spec, executor):
 # Internal functions
 
 
-def run_operation(tmp_path, executor, name, result_array, *, optimize_function=None):
-    # result_array.visualize(f"cubed-{name}-unoptimized", optimize_graph=False)
+def run_operation(
+    tmp_path,
+    executor,
+    name,
+    result_array,
+    *,
+    optimize_graph=True,
+    optimize_function=None,
+):
+    # result_array.visualize(f"cubed-{name}-unoptimized", optimize_graph=False, show_hidden=True)
     # result_array.visualize(f"cubed-{name}", optimize_function=optimize_function)
     hist = HistoryCallback()
     mem_warn = MemoryWarningCallback()
+    memray = MemrayCallback()
     # use store=None to write to temporary zarr
     cubed.to_zarr(
         result_array,
         store=None,
         executor=executor,
-        callbacks=[hist, mem_warn],
+        callbacks=[hist, mem_warn, memray],
+        optimize_graph=optimize_graph,
         optimize_function=optimize_function,
     )
 
@@ -328,6 +356,13 @@ def run_operation(tmp_path, executor, name, result_array, *, optimize_function=N
     # check change in peak memory is no more than projected mem
     assert (df["peak_measured_mem_delta_mb_max"] <= df["projected_mem_mb"]).all()
 
+    # check memray peak memory allocated is no more than projected mem
+    for op_name, stats in memray.stats.items():
+        assert (
+            stats.peak_memory_allocated
+            <= df.query(f"name=='{op_name}'")["projected_mem_mb"].item() * 1_000_000
+        ), f"memray's peak allocated exceeds projected mem for {op_name}"
+
     # check projected_mem_utilization does not exceed 1
     # except on processes executor that runs multiple tasks in a process
     if (