@@ -66,9 +66,6 @@ class DiscreteFunction(AbstractFunction, ArgProvider, Differentiable):
6666 __rkwargs__ = AbstractFunction .__rkwargs__ + ('staggered' , 'coefficients' )
6767
6868 def __init_finalize__ (self , * args , function = None , ** kwargs ):
69- # Staggering metadata
70- self ._staggered = self .__staggered_setup__ (** kwargs )
71-
7269 # Now that *all* __X_setup__ hooks have been called, we can let the
7370 # superclass constructor do its job
7471 super ().__init_finalize__ (* args , ** kwargs )
@@ -180,18 +177,6 @@ def __coefficients_setup__(self, **kwargs):
180177 " not %s" % (str (fd_weights_registry ), coeffs ))
181178 return coeffs
182179
183- def __staggered_setup__ (self , ** kwargs ):
184- """
185- Setup staggering-related metadata. This method assigns:
186-
187- * 0 to non-staggered dimensions;
188- * 1 to staggered dimensions.
189- """
190- staggered = kwargs .get ('staggered' , None )
191- if staggered is CELL :
192- staggered = self .dimensions
193- return staggered
194-
195180 @cached_property
196181 def _functions (self ):
197182 return {self .function }
@@ -208,10 +193,6 @@ def _mem_external(self):
208193 def _mem_heap (self ):
209194 return True
210195
211- @property
212- def staggered (self ):
213- return self ._staggered
214-
215196 @property
216197 def coefficients (self ):
217198 """Form of the coefficients of the function."""
@@ -1077,34 +1058,49 @@ def _eval_at(self, func):
10771058 return self .subs (mapper )
10781059 return self
10791060
1061+ @classmethod
1062+ def __staggered_setup__ (cls , dimensions , ** kwargs ):
1063+ """
1064+ Setup staggering-related metadata. This method assigns:
1065+
1066+ * 0 to non-staggered dimensions;
1067+ * 1 to staggered dimensions.
1068+ """
1069+ stagg = kwargs .get ('staggered' , None )
1070+ if stagg is CELL :
1071+ staggered = (sympy .S .One for d in dimensions )
1072+ elif stagg in [None , NODE ]:
1073+ staggered = (sympy .S .Zero for d in dimensions )
1074+ elif all (is_integer (s ) for s in as_tuple (stagg )):
1075+ # Staggering is already a tuple likely from rebuild
1076+ assert len (stagg ) == len (dimensions )
1077+ return tuple (stagg )
1078+ else :
1079+ staggered = (sympy .S .One if d in as_tuple (stagg ) else sympy .S .Zero
1080+ for d in dimensions )
1081+ return tuple (staggered )
1082+
10801083 @classmethod
10811084 def __indices_setup__ (cls , * args , ** kwargs ):
10821085 grid = kwargs .get ('grid' )
10831086 dimensions = kwargs .get ('dimensions' )
1087+ staggered = kwargs .get ('staggered' )
1088+
10841089 if grid is None :
10851090 if dimensions is None :
10861091 raise TypeError ("Need either `grid` or `dimensions`" )
10871092 elif dimensions is None :
10881093 dimensions = grid .dimensions
10891094
1095+ staggered = cls .__staggered_setup__ (dimensions , staggered = staggered )
10901096 if args :
10911097 assert len (args ) == len (dimensions )
1092- return tuple (dimensions ), tuple (args )
1093-
1094- # Staggered indices
1095- staggered = kwargs .get ("staggered" , None )
1096- if staggered in [None , NODE ]:
1097- staggered_indices = dimensions
1098- elif staggered == CELL :
1099- staggered_indices = [d + d .spacing / 2 for d in dimensions ]
1098+ staggered_indices = tuple (args )
11001099 else :
1101- mapper = {d : d for d in dimensions }
1102- for s in as_tuple (staggered ):
1103- c , s = s .as_coeff_Mul ()
1104- mapper .update ({s : s + c * s .spacing / 2 })
1105- staggered_indices = mapper .values ()
1106-
1107- return tuple (dimensions ), tuple (staggered_indices )
1100+ # Staggered indices
1101+ staggered_indices = (d + i * d .spacing / 2
1102+ for d , i in zip (dimensions , staggered ))
1103+ return tuple (dimensions ), tuple (staggered_indices ), staggered
11081104
11091105 @property
11101106 def is_Staggered (self ):
@@ -1604,7 +1600,7 @@ def __indices_setup__(cls, **kwargs):
16041600 # Sanity check
16051601 assert not any (d .is_NonlinearDerived for d in dimensions )
16061602
1607- return dimensions , dimensions
1603+ return dimensions , dimensions , ( sympy . S . Zero for _ in dimensions )
16081604
16091605 def __halo_setup__ (self , ** kwargs ):
16101606 pointer_dim = kwargs .get ('pointer_dim' )
0 commit comments