Skip to content

Commit ad1d4f4

Browse files
authored
Merge pull request #2659 from devitocodes/estimate-memory
compiler: Add a utility to estimate memory usage for an operator
2 parents dd1e6c7 + 23428df commit ad1d4f4

File tree

8 files changed

+745
-127
lines changed

8 files changed

+745
-127
lines changed

.github/workflows/pytest-gpu.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
include:
4444
# -------------------- NVIDIA job --------------------
4545
- name: pytest-gpu-acc-nvidia
46-
test_files: "tests/test_adjoint.py tests/test_gpu_common.py tests/test_gpu_openacc.py"
46+
test_files: "tests/test_adjoint.py tests/test_gpu_common.py tests/test_gpu_openacc.py tests/test_operator.py::TestEstimateMemory"
4747
base: "devitocodes/bases:nvidia-nvc"
4848
runner_label: nvidiagpu
4949
test_drive_cmd: "nvidia-smi"
@@ -56,7 +56,7 @@ jobs:
5656
5757
# -------------------- AMD job -----------------------
5858
- name: pytest-gpu-omp-amd
59-
test_files: "tests/test_adjoint.py tests/test_gpu_common.py tests/test_gpu_openmp.py"
59+
test_files: "tests/test_adjoint.py tests/test_gpu_common.py tests/test_gpu_openmp.py tests/test_operator.py::TestEstimateMemory"
6060
runner_label: amdgpu
6161
base: "devitocodes/bases:amd"
6262
test_drive_cmd: "rocm-smi"

FAQ.md

Lines changed: 67 additions & 84 deletions
Large diffs are not rendered by default.

devito/operator/operator.py

Lines changed: 187 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from tempfile import gettempdir
88

99
from sympy import sympify
10+
import sympy
1011
import numpy as np
1112

1213
from devito.arch import ANYCPU, Device, compiler_registry, platform_registry
@@ -33,7 +34,7 @@
3334
from devito.tools import (DAG, OrderedSet, Signer, ReducerMap, as_mapper, as_tuple,
3435
flatten, filter_sorted, frozendict, is_integer,
3536
split, timed_pass, timed_region, contains_val,
36-
CacheInstances)
37+
CacheInstances, MemoryEstimate)
3738
from devito.types import (Buffer, Evaluable, host_layer, device_layer,
3839
disk_layer)
3940
from devito.types.dimension import Thickness
@@ -42,6 +43,9 @@
4243
__all__ = ['Operator']
4344

4445

46+
_layers = (disk_layer, host_layer, device_layer)
47+
48+
4549
class Operator(Callable):
4650

4751
"""
@@ -554,7 +558,7 @@ def _access_modes(self):
554558
return frozendict({i: AccessMode(i in self.reads, i in self.writes)
555559
for i in self.input})
556560

557-
def _prepare_arguments(self, autotune=None, **kwargs):
561+
def _prepare_arguments(self, autotune=None, estimate_memory=False, **kwargs):
558562
"""
559563
Process runtime arguments passed to ``.apply()`` and derive
560564
default values for any remaining arguments.
@@ -602,6 +606,7 @@ def _prepare_arguments(self, autotune=None, **kwargs):
602606

603607
# Prepare to process data-carriers
604608
args = kwargs['args'] = ReducerMap()
609+
605610
kwargs['metadata'] = {'language': self._language,
606611
'platform': self._platform,
607612
'transients': self.transients,
@@ -611,7 +616,7 @@ def _prepare_arguments(self, autotune=None, **kwargs):
611616

612617
# Process data-carrier overrides
613618
for p in overrides:
614-
args.update(p._arg_values(**kwargs))
619+
args.update(p._arg_values(estimate_memory=estimate_memory, **kwargs))
615620
try:
616621
args.reduce_inplace()
617622
except ValueError:
@@ -625,7 +630,7 @@ def _prepare_arguments(self, autotune=None, **kwargs):
625630
if p.name in args:
626631
# E.g., SubFunctions
627632
continue
628-
for k, v in p._arg_values(**kwargs).items():
633+
for k, v in p._arg_values(estimate_memory=estimate_memory, **kwargs).items():
629634
if k not in args:
630635
args[k] = v
631636
elif k in futures:
@@ -653,6 +658,10 @@ def _prepare_arguments(self, autotune=None, **kwargs):
653658
# the subsequent phases of the arguments processing
654659
args = kwargs['args'] = ArgumentsMap(args, grid, self)
655660

661+
if estimate_memory:
662+
# No need to do anything more if only checking the memory
663+
return args
664+
656665
# Process Dimensions
657666
for d in reversed(toposort):
658667
args.update(d._arg_values(self._dspace[d], grid, **kwargs))
@@ -866,6 +875,53 @@ def cinterface(self, force=False):
866875
def __call__(self, **kwargs):
867876
return self.apply(**kwargs)
868877

878+
def estimate_memory(self, **kwargs):
    """
    Estimate the memory an execution of this Operator would consume,
    without touching or allocating any data.

    The call signature mirrors `Operator.apply(**kwargs)`: invoked with no
    arguments, the estimate is made for the default Operator parameters; any
    overrides supplied (as per `apply`) are used for the estimate instead.

    When the Operator is expected to allocate large arrays, avoid touching
    the data from Python beforehand. `AbstractFunction` data is allocated
    lazily - the underlying array only comes into existence when `data`,
    `data_with_halo`, etc, are first accessed - so a script which never
    accesses those attributes can report the nominal memory usage of
    Operators far larger than will fit in system DRAM.

    Note that the Operator is built as part of the estimate, so that array
    temporaries and buffers generated during compilation are factored in.

    Parameters
    ----------
    **kwargs: dict
        As per `Operator.apply()`.

    Returns
    -------
    summary: MemoryEstimate
        An estimate of memory consumed in each of the specified locations.
    """
    # Process the arguments in estimate-only mode, so that any overrides
    # are reflected in the reported figures
    arguments = self._prepare_arguments(estimate_memory=True, **kwargs)
    consumed = arguments.nbytes_consumed

    report = {
        'host': consumed[host_layer],
        'device': consumed[device_layer],
    }

    # Enriched Operators may contribute additional entries
    report.update(self._enrich_memreport(arguments))

    return MemoryEstimate(report, name=self.name)
920+
921+
def _enrich_memreport(self, args):
    """
    Hook for enriching the memory report with additional metadata.

    This implementation contributes nothing: it ignores `args` and returns
    an empty dict.
    """
    return dict()
924+
869925
def apply(self, **kwargs):
870926
"""
871927
Execute the Operator.
@@ -1283,6 +1339,41 @@ def saved_mapper(self):
12831339

12841340
return mapper
12851341

1342+
@cached_property
def _op_symbols(self):
    """
    Symbols found by visiting `self.op` with `FindSymbols`. These may or
    may not carry data.
    """
    finder = FindSymbols()
    return finder.visit(self.op)
1346+
1347+
@cached_property
def _op_functions(self):
    """
    Function symbols in the Operator: the entries of `_op_symbols` that
    are flagged `is_DiscreteFunction` and are not aliases.
    """
    functions = []
    for sym in self._op_symbols:
        if not sym.is_DiscreteFunction:
            continue
        if sym.alias:
            continue
        functions.append(sym)
    return functions
1351+
1352+
def _apply_override(self, i):
    """
    Return the object to use in place of symbol `i`.

    The entry stored under `i.name` is looked up once, falling back to `i`
    itself when there is no such entry. If the result exposes an `_obj`
    attribute, that attribute is returned; otherwise the result itself is
    returned.
    """
    override = self.get(i.name, i)
    # A single lookup plus `getattr` with a default replaces the previous
    # try/except, which repeated the lookup whenever `_obj` was missing
    return getattr(override, '_obj', override)
1357+
1358+
def _get_nbytes(self, i):
    """
    Return the allocated size of symbol `i`, after applying any override
    registered for it.
    """
    target = self._apply_override(i)

    # Non-regular AbstractFunctions (compressed, etc) expose `nbytes_max`;
    # garden-variety ones only provide `nbytes`
    try:
        size = target.nbytes_max
    except AttributeError:
        size = target.nbytes

    # Anything that is not a sympy expression is returned untouched
    if not isinstance(size, sympy.Basic):
        return size

    # The size is still symbolic at this point, so it is resolved through
    # `subs_op_args` against these arguments
    return subs_op_args(size, self)
1376+
12861377
@cached_property
12871378
def nbytes_avail_mapper(self):
12881379
"""
@@ -1307,9 +1398,69 @@ def nbytes_avail_mapper(self):
13071398
nproc = 1
13081399
mapper[host_layer] = int(ANYCPU.memavail() / nproc)
13091400

1401+
for layer in (host_layer, device_layer):
1402+
try:
1403+
mapper[layer] -= self.nbytes_consumed_operator.get(layer, 0)
1404+
except KeyError: # Might not have this layer in the mapper
1405+
pass
1406+
1407+
mapper = {k: int(v) for k, v in mapper.items()}
1408+
1409+
return mapper
1410+
1411+
@cached_property
def nbytes_consumed(self):
    """
    Memory consumed by all objects in the Operator, as a mapper from each
    layer in `_layers` to the sum of the per-layer figures for Functions,
    Arrays and memory-mapped data.
    """
    contributions = [
        self.nbytes_consumed_functions,
        self.nbytes_consumed_arrays,
        self.nbytes_consumed_memmapped,
    ]

    totals = {}
    for layer in _layers:
        totals[layer] = sum(c[layer] for c in contributions)

    return totals
1420+
1421+
@cached_property
def nbytes_consumed_operator(self):
    """
    Memory consumed by objects allocated within the Operator, as a mapper
    from each layer in `_layers` to the sum of the per-layer figures for
    Arrays and memory-mapped data.
    """
    contributions = [
        self.nbytes_consumed_arrays,
        self.nbytes_consumed_memmapped,
    ]

    totals = {}
    for layer in _layers:
        totals[layer] = sum(c[layer] for c in contributions)

    return totals
1429+
1430+
@cached_property
def nbytes_consumed_functions(self):
    """
    Memory consumed on host and device by the Functions (see
    `_op_functions`) in the corresponding Operator. The disk layer is
    always reported as zero.
    """
    totals = {disk_layer: 0, host_layer: 0, device_layer: 0}

    for f in self._op_functions:
        size = self._get_nbytes(f)

        if f._mem_host or f._mem_mapped:
            # Mapped data is only charged to the host here; the device
            # side is counted by `nbytes_consumed_memmapped`
            layer = host_layer
        elif f._mem_local:
            # Local data goes to the device when running on a Device
            # platform, and to the host otherwise
            layer = device_layer if isinstance(self.platform, Device) else host_layer
        else:
            continue

        totals[layer] += size

    return totals
1452+
1453+
@cached_property
1454+
def nbytes_consumed_arrays(self):
1455+
"""
1456+
Memory consumed on both device and host by C-land Arrays
1457+
in the corresponding Operator.
1458+
"""
1459+
host = 0
1460+
device = 0
13101461
# Temporaries such as Arrays are allocated and deallocated on-the-fly
13111462
# while in C land, so they need to be accounted for as well
1312-
for i in FindSymbols().visit(self.op):
1463+
for i in self._op_symbols:
13131464
if not i.is_Array or not i._mem_heap or i.alias:
13141465
continue
13151466

@@ -1323,17 +1474,26 @@ def nbytes_avail_mapper(self):
13231474
continue
13241475

13251476
if i._mem_host:
1326-
mapper[host_layer] -= v
1477+
host += v
13271478
elif i._mem_local:
13281479
if isinstance(self.platform, Device):
1329-
mapper[device_layer] -= v
1480+
device += v
13301481
else:
1331-
mapper[host_layer] -= v
1482+
host += v
13321483
elif i._mem_mapped:
13331484
if isinstance(self.platform, Device):
1334-
mapper[device_layer] -= v
1335-
mapper[host_layer] -= v
1485+
device += v
1486+
host += v
1487+
1488+
return {disk_layer: 0, host_layer: host, device_layer: device}
13361489

1490+
@cached_property
1491+
def nbytes_consumed_memmapped(self):
1492+
"""
1493+
Memory also consumed on device by data which is to be memcpy-d
1494+
from host to device at the start of computation.
1495+
"""
1496+
device = 0
13371497
# All input Functions are yet to be memcpy-ed to the device
13381498
# TODO: this may not be true depending on `devicerm`, which is however
13391499
# virtually never used
@@ -1343,17 +1503,27 @@ def nbytes_avail_mapper(self):
13431503
continue
13441504
try:
13451505
if i._mem_mapped:
1346-
try:
1347-
v = self[i.name]._obj.nbytes
1348-
except AttributeError:
1349-
v = i.nbytes
1350-
mapper[device_layer] -= v
1506+
device += self._get_nbytes(i)
13511507
except AttributeError:
13521508
pass
13531509

1354-
mapper = {k: int(v) for k, v in mapper.items()}
1510+
return {disk_layer: 0, host_layer: 0, device_layer: device}
13551511

1356-
return mapper
1512+
@cached_property
def nbytes_snapshots(self):
    """
    Bytes attributed to the disk layer by snapshots, computed per symbol
    as `size_snapshot * _time_size_ideal * itemsize` after applying any
    override. Host and device layers are always reported as zero.

    Symbols lacking any of the required attributes are skipped.
    """
    total = 0
    symbols = self._op_symbols

    for sym in symbols:
        try:
            # Use only the "innermost" layer, i.e. skip symbols whose
            # `_child` is itself among the Operator symbols, to avoid
            # counting snapshots twice
            if sym._child in symbols:
                continue

            obj = self._apply_override(sym)
            nitems = obj.size_snapshot*obj._time_size_ideal
            total += nitems*np.dtype(obj.dtype).itemsize
        except AttributeError:
            # Not a streamed function
            continue

    return {disk_layer: total, host_layer: 0, device_layer: 0}
13571527

13581528

13591529
def parse_kwargs(**kwargs):

0 commit comments

Comments
 (0)