2626from devito .types .args import ArgProvider
2727from devito .types .caching import CacheManager
2828from devito .types .basic import AbstractFunction , Size
29- from devito .types .utils import Buffer , DimensionTuple , NODE , CELL , host_layer
29+ from devito .types .utils import Buffer , DimensionTuple , NODE , CELL , host_layer , Staggering
3030
3131__all__ = ['Function' , 'TimeFunction' , 'SubFunction' , 'TempFunction' ]
3232
@@ -1010,6 +1010,10 @@ def _cache_meta(self):
10101010 def __init_finalize__ (self , * args , ** kwargs ):
10111011 super ().__init_finalize__ (* args , ** kwargs )
10121012
1013+ # Staggering
1014+ self ._staggered = self .__staggered_setup__ (self .dimensions ,
1015+ staggered = kwargs .get ('staggered' ))
1016+
10131017 # Space order
10141018 space_order = kwargs .get ('space_order' , 1 )
10151019 if isinstance (space_order , int ):
@@ -1042,7 +1046,7 @@ def __fd_setup__(self):
10421046
10431047 @cached_property
10441048 def _fd_priority (self ):
1045- return 1 if self .staggered in [ NODE , None ] else 2
1049+ return 1 if self .staggered . on_node else 2
10461050
10471051 @property
10481052 def is_parameter (self ):
@@ -1059,26 +1063,33 @@ def _eval_at(self, func):
10591063 return self
10601064
10611065 @classmethod
1062- def __staggered_setup__ (cls , dimensions , ** kwargs ):
1066+ def __staggered_setup__ (cls , dimensions , staggered = None , ** kwargs ):
10631067 """
10641068 Setup staggering-related metadata. This method assigns:
10651069
10661070 * 0 to non-staggered dimensions;
10671071 * 1 to staggered dimensions.
10681072 """
1069- stagg = kwargs .get ('staggered' , None )
1070- if stagg is CELL :
1071- staggered = (sympy .S .One for d in dimensions )
1072- elif stagg in [None , NODE ]:
1073- staggered = (sympy .S .Zero for d in dimensions )
1074- elif all (is_integer (s ) for s in as_tuple (stagg )):
1073+ if not staggered :
1074+ processed = ()
1075+ elif staggered is CELL :
1076+ processed = (sympy .S .One ,)* len (dimensions )
1077+ elif staggered is NODE :
1078+ processed = (sympy .S .Zero ,)* len (dimensions )
1079+ elif all (is_integer (s ) for s in as_tuple (staggered )):
10751080 # Staggering is already a tuple likely from rebuild
1076- assert len (stagg ) == len (dimensions )
1077- return tuple ( stagg )
1081+ assert len (staggered ) == len (dimensions )
1082+ processed = staggered
10781083 else :
1079- staggered = (sympy .S .One if d in as_tuple (stagg ) else sympy .S .Zero
1080- for d in dimensions )
1081- return tuple (staggered )
1084+ processed = []
1085+ for d in dimensions :
1086+ if d in as_tuple (staggered ):
1087+ processed .append (sympy .S .One )
1088+ elif - d in as_tuple (staggered ):
1089+ processed .append (sympy .S .NegativeOne )
1090+ else :
1091+ processed .append (sympy .S .Zero )
1092+ return tuple (processed )
10821093
10831094 @classmethod
10841095 def __indices_setup__ (cls , * args , ** kwargs ):
@@ -1097,14 +1108,27 @@ def __indices_setup__(cls, *args, **kwargs):
10971108 assert len (args ) == len (dimensions )
10981109 staggered_indices = tuple (args )
10991110 else :
1100- # Staggered indices
1101- staggered_indices = (d + i * d .spacing / 2
1102- for d , i in zip (dimensions , staggered ))
1103- return tuple (dimensions ), tuple (staggered_indices ), staggered
1111+ if not staggered :
1112+ staggered_indices = (d for d in dimensions )
1113+ else :
1114+ # Staggered indices
1115+ staggered_indices = (d + i * d .spacing / 2
1116+ for d , i in zip (dimensions , staggered ))
1117+ return tuple (dimensions ), tuple (staggered_indices )
1118+
1119+ @property
1120+ def staggered (self ):
1121+ """The staggered indices of the object."""
1122+ if self ._staggered :
1123+ return Staggering (* self ._staggered , getters = self .dimensions )
1124+ else :
1125+ return Staggering (getters = self .dimensions )
11041126
11051127 @property
11061128 def is_Staggered (self ):
1107- return self .staggered is not None
1129+ if not self .staggered :
1130+ return False
1131+ return True
11081132
11091133 @classmethod
11101134 def __shape_setup__ (cls , ** kwargs ):
@@ -1392,7 +1416,6 @@ def __fd_setup__(self):
13921416 @classmethod
13931417 def __indices_setup__ (cls , * args , ** kwargs ):
13941418 dimensions = kwargs .get ('dimensions' )
1395- staggered = kwargs .get ('staggered' )
13961419
13971420 if dimensions is None :
13981421 save = kwargs .get ('save' )
@@ -1407,7 +1430,7 @@ def __indices_setup__(cls, *args, **kwargs):
14071430 dimensions .insert (cls ._time_position , time_dim )
14081431
14091432 return Function .__indices_setup__ (
1410- * args , dimensions = dimensions , staggered = staggered
1433+ * args , dimensions = dimensions , staggered = kwargs . get ( ' staggered' )
14111434 )
14121435
14131436 @classmethod
@@ -1446,7 +1469,7 @@ def __shape_setup__(cls, **kwargs):
14461469
14471470 @cached_property
14481471 def _fd_priority (self ):
1449- return 2.1 if self .staggered in [ NODE , None ] else 2.2
1472+ return 2.1 if self .staggered . on_node else 2.2
14501473
14511474 @property
14521475 def time_order (self ):
@@ -1600,7 +1623,7 @@ def __indices_setup__(cls, **kwargs):
16001623 # Sanity check
16011624 assert not any (d .is_NonlinearDerived for d in dimensions )
16021625
1603- return dimensions , dimensions , ( sympy . S . Zero for _ in dimensions )
1626+ return dimensions , dimensions
16041627
16051628 def __halo_setup__ (self , ** kwargs ):
16061629 pointer_dim = kwargs .get ('pointer_dim' )
0 commit comments