77from tempfile import gettempdir
88
99from sympy import sympify
10+ import sympy
1011import numpy as np
1112
1213from devito .arch import ANYCPU , Device , compiler_registry , platform_registry
3334from devito .tools import (DAG , OrderedSet , Signer , ReducerMap , as_mapper , as_tuple ,
3435 flatten , filter_sorted , frozendict , is_integer ,
3536 split , timed_pass , timed_region , contains_val ,
36- CacheInstances )
37+ CacheInstances , MemoryEstimate )
3738from devito .types import (Buffer , Evaluable , host_layer , device_layer ,
3839 disk_layer )
3940from devito .types .dimension import Thickness
4243__all__ = ['Operator' ]
4344
4445
# Memory layers over which the per-layer `nbytes_consumed*` totals are computed
46+ _layers = (disk_layer , host_layer , device_layer )
47+
48+
4549class Operator (Callable ):
4650
4751 """
@@ -554,7 +558,7 @@ def _access_modes(self):
554558 return frozendict ({i : AccessMode (i in self .reads , i in self .writes )
555559 for i in self .input })
556560
557- def _prepare_arguments (self , autotune = None , ** kwargs ):
561+ def _prepare_arguments (self , autotune = None , estimate_memory = False , ** kwargs ):
558562 """
559563 Process runtime arguments passed to ``.apply()` and derive
560564 default values for any remaining arguments.
@@ -602,6 +606,7 @@ def _prepare_arguments(self, autotune=None, **kwargs):
602606
603607 # Prepare to process data-carriers
604608 args = kwargs ['args' ] = ReducerMap ()
609+
605610 kwargs ['metadata' ] = {'language' : self ._language ,
606611 'platform' : self ._platform ,
607612 'transients' : self .transients ,
@@ -611,7 +616,7 @@ def _prepare_arguments(self, autotune=None, **kwargs):
611616
612617 # Process data-carrier overrides
613618 for p in overrides :
614- args .update (p ._arg_values (** kwargs ))
619+ args .update (p ._arg_values (estimate_memory = estimate_memory , ** kwargs ))
615620 try :
616621 args .reduce_inplace ()
617622 except ValueError :
@@ -625,7 +630,7 @@ def _prepare_arguments(self, autotune=None, **kwargs):
625630 if p .name in args :
626631 # E.g., SubFunctions
627632 continue
628- for k , v in p ._arg_values (** kwargs ).items ():
633+ for k , v in p ._arg_values (estimate_memory = estimate_memory , ** kwargs ).items ():
629634 if k not in args :
630635 args [k ] = v
631636 elif k in futures :
@@ -653,6 +658,10 @@ def _prepare_arguments(self, autotune=None, **kwargs):
653658 # the subsequent phases of the arguments processing
654659 args = kwargs ['args' ] = ArgumentsMap (args , grid , self )
655660
661+ if estimate_memory :
662+ # No need to do anything more if only checking the memory
663+ return args
664+
656665 # Process Dimensions
657666 for d in reversed (toposort ):
658667 args .update (d ._arg_values (self ._dspace [d ], grid , ** kwargs ))
@@ -866,6 +875,53 @@ def cinterface(self, force=False):
# Calling the object forwards all keyword arguments, unchanged, to `self.apply`
866875 def __call__ (self , ** kwargs ):
867876 return self .apply (** kwargs )
868877
def estimate_memory(self, **kwargs):
    """
    Estimate the memory that running this Operator would consume, without
    touching or allocating any data.

    The interface mimics `Operator.apply(**kwargs)`: it may be called with the
    kwargs of a prospective Operator execution. Called with no arguments, the
    estimate refers to the default Operator parameters; any overrides supplied
    (as per `apply`) are used for the estimate instead.

    When the Operator is expected to allocate large arrays, it is strongly
    recommended not to touch the data from Python, since that would trigger
    allocation. `AbstractFunction` types allocate their data lazily - the
    underlying array is only created when the `data`, `data_with_halo`, etc,
    attributes are first accessed. A script that avoids accessing such
    attributes can therefore check the nominal memory usage of Operators far
    larger than will fit in system DRAM.

    Note that the Operator is built in order to factor in the memory allocated
    for array temporaries and buffers generated during compilation.

    Parameters
    ----------
    **kwargs: dict
        As per `Operator.apply()`.

    Returns
    -------
    summary: MemoryEstimate
        An estimate of memory consumed in each of the specified locations.
    """
    # Process the arguments in memory-estimation mode, so that overrides
    # are reflected in the reported sizes
    arguments = self._prepare_arguments(estimate_memory=True, **kwargs)
    consumed = arguments.nbytes_consumed

    report = {}
    report['host'] = consumed[host_layer]
    report['device'] = consumed[device_layer]

    # Enriched Operators may contribute extra entries
    report.update(self._enrich_memreport(arguments))

    return MemoryEstimate(report, name=self.name)
920+
921+ def _enrich_memreport (self , args ):
922+ # Hook for enriching memory report with additional metadata
923+ return {}
924+
869925 def apply (self , ** kwargs ):
870926 """
871927 Execute the Operator.
@@ -1283,6 +1339,41 @@ def saved_mapper(self):
12831339
12841340 return mapper
12851341
@cached_property
def _op_symbols(self):
    """
    Symbols found by visiting the Operator with `FindSymbols`; these may or
    may not carry data.
    """
    finder = FindSymbols()
    return finder.visit(self.op)
1346+
1347+ @cached_property
1348+ def _op_functions (self ):
1349+ """Function symbols in the Operator"""
1350+ return [i for i in self ._op_symbols if i .is_DiscreteFunction and not i .alias ]
1351+
1352+ def _apply_override (self , i ):
1353+ try :
1354+ return self .get (i .name , i )._obj
1355+ except AttributeError :
1356+ return self .get (i .name , i )
1357+
def _get_nbytes(self, i):
    """
    Extract the allocated size of the symbol `i`, accounting for any
    overrides.

    `nbytes_max` is preferred when the (possibly overridden) object provides
    it, with `nbytes` as the fallback. A symbolic size is resolved through
    `subs_op_args` against this mapping before being returned.
    """
    target = self._apply_override(i)
    try:
        # Non-regular AbstractFunction (compressed, etc)
        size = target.nbytes_max
    except AttributeError:
        # Garden-variety AbstractFunction
        size = target.nbytes

    if not isinstance(size, sympy.Basic):
        return size

    # The size could nominally still be symbolic at this point
    return subs_op_args(size, self)
1376+
12861377 @cached_property
12871378 def nbytes_avail_mapper (self ):
12881379 """
@@ -1307,9 +1398,69 @@ def nbytes_avail_mapper(self):
13071398 nproc = 1
13081399 mapper [host_layer ] = int (ANYCPU .memavail () / nproc )
13091400
1401+ for layer in (host_layer , device_layer ):
1402+ try :
1403+ mapper [layer ] -= self .nbytes_consumed_operator .get (layer , 0 )
1404+ except KeyError : # Might not have this layer in the mapper
1405+ pass
1406+
1407+ mapper = {k : int (v ) for k , v in mapper .items ()}
1408+
1409+ return mapper
1410+
@cached_property
def nbytes_consumed(self):
    """
    Memory consumed by all objects in the Operator.

    For each layer in `_layers`, this is the sum of the corresponding entries
    of `nbytes_consumed_functions`, `nbytes_consumed_arrays` and
    `nbytes_consumed_memmapped`.
    """
    contributions = [
        self.nbytes_consumed_functions,
        self.nbytes_consumed_arrays,
        self.nbytes_consumed_memmapped,
    ]

    totals = {}
    for layer in _layers:
        totals[layer] = sum(part[layer] for part in contributions)

    return totals
1420+
@cached_property
def nbytes_consumed_operator(self):
    """
    Memory consumed by objects allocated within the Operator.

    For each layer in `_layers`, this is the sum of the corresponding entries
    of `nbytes_consumed_arrays` and `nbytes_consumed_memmapped`.
    """
    contributions = [
        self.nbytes_consumed_arrays,
        self.nbytes_consumed_memmapped,
    ]

    totals = {}
    for layer in _layers:
        totals[layer] = sum(part[layer] for part in contributions)

    return totals
1429+
@cached_property
def nbytes_consumed_functions(self):
    """
    Memory consumed on both device and host by Functions in the
    corresponding Operator.

    Only the entries of `_op_functions` are considered. The disk layer is
    always reported as zero.
    """
    totals = {disk_layer: 0, host_layer: 0, device_layer: 0}

    for func in self._op_functions:
        size = self._get_nbytes(func)

        if func._mem_host or func._mem_mapped:
            # Mapped data is not added to the device here, as that is
            # counted by `nbytes_consumed_memmapped`
            layer = host_layer
        elif func._mem_local:
            # Local data lives on the device only if the platform is one
            layer = device_layer if isinstance(self.platform, Device) else host_layer
        else:
            continue

        totals[layer] += size

    return totals
1452+
1453+ @cached_property
1454+ def nbytes_consumed_arrays (self ):
1455+ """
1456+ Memory consumed on both device and host by C-land Arrays
1457+ in the corresponding Operator.
1458+ """
1459+ host = 0
1460+ device = 0
13101461 # Temporaries such as Arrays are allocated and deallocated on-the-fly
13111462 # while in C land, so they need to be accounted for as well
1312- for i in FindSymbols (). visit ( self .op ) :
1463+ for i in self ._op_symbols :
13131464 if not i .is_Array or not i ._mem_heap or i .alias :
13141465 continue
13151466
@@ -1323,17 +1474,26 @@ def nbytes_avail_mapper(self):
13231474 continue
13241475
13251476 if i ._mem_host :
1326- mapper [ host_layer ] - = v
1477+ host + = v
13271478 elif i ._mem_local :
13281479 if isinstance (self .platform , Device ):
1329- mapper [ device_layer ] - = v
1480+ device + = v
13301481 else :
1331- mapper [ host_layer ] - = v
1482+ host + = v
13321483 elif i ._mem_mapped :
13331484 if isinstance (self .platform , Device ):
1334- mapper [device_layer ] -= v
1335- mapper [host_layer ] -= v
1485+ device += v
1486+ host += v
1487+
1488+ return {disk_layer : 0 , host_layer : host , device_layer : device }
13361489
1490+ @cached_property
1491+ def nbytes_consumed_memmapped (self ):
1492+ """
1493+ Memory also consumed on device by data which is to be memcpy-d
1494+ from host to device at the start of computation.
1495+ """
1496+ device = 0
13371497 # All input Functions are yet to be memcpy-ed to the device
13381498 # TODO: this may not be true depending on `devicerm`, which is however
13391499 # virtually never used
@@ -1343,17 +1503,27 @@ def nbytes_avail_mapper(self):
13431503 continue
13441504 try :
13451505 if i ._mem_mapped :
1346- try :
1347- v = self [i .name ]._obj .nbytes
1348- except AttributeError :
1349- v = i .nbytes
1350- mapper [device_layer ] -= v
1506+ device += self ._get_nbytes (i )
13511507 except AttributeError :
13521508 pass
13531509
1354- mapper = { k : int ( v ) for k , v in mapper . items () }
1510+ return { disk_layer : 0 , host_layer : 0 , device_layer : device }
13551511
1356- return mapper
@cached_property
def nbytes_snapshots(self):
    """
    Bytes attributed to the disk layer by snapshots, summed over the symbols
    in `_op_symbols`.

    A symbol contributes `size_snapshot * _time_size_ideal * itemsize` of its
    (possibly overridden) object, provided its `_child` is not itself among
    the Operator symbols. Symbols lacking any of the required attributes are
    skipped. Host and device layers are always reported as zero.
    """
    symbols = self._op_symbols
    total = 0

    for sym in symbols:
        try:
            if sym._child in symbols:
                # Use only the "innermost" layer to avoid counting snapshots
                # twice
                continue
            obj = self._apply_override(sym)
            count = obj.size_snapshot * obj._time_size_ideal
            total += count * np.dtype(obj.dtype).itemsize
        except AttributeError:
            # Not a streamed function
            continue

    return {disk_layer: total, host_layer: 0, device_layer: 0}
13571527
13581528
13591529def parse_kwargs (** kwargs ):
0 commit comments