@@ -155,11 +155,12 @@ def coefficients(self):
155155 key = lambda x : coeff_priority .get (x , - 1 )
156156 return sorted (coefficients , key = key , reverse = True )[0 ]
157157
158- def _eval_at (self , func ):
158+ def _eval_at (self , func , ** kwargs ):
159159 if not func .is_Staggered :
160160 # Cartesian grid, do no waste time
161161 return self
162- return self .func (* [getattr (a , '_eval_at' , lambda x : a )(func ) for a in self .args ])
162+ return self .func (* [getattr (a , '_eval_at' , lambda x , ** kw : a )(func , ** kwargs )
163+ for a in self .args ])
163164
164165 def _subs (self , old , new , ** hints ):
165166 if old == self :
@@ -466,7 +467,11 @@ def highest_priority(DiffOp):
466467 # set of dimensions is used when multiple ones with the same
467468 # priority appear
468469 prio = lambda x : (getattr (x , '_fd_priority' , 0 ), len (x .dimensions ))
469- return sorted (DiffOp ._args_diff , key = prio , reverse = True )[0 ]
470+ args = DiffOp ._args_diff
471+ if not args :
472+ return DiffOp
473+ else :
474+ return sorted (DiffOp ._args_diff , key = prio , reverse = True )[0 ]
470475
471476
472477class DifferentiableOp (Differentiable ):
@@ -532,7 +537,7 @@ class DifferentiableFunction(DifferentiableOp):
532537 def __new__ (cls , * args , ** kwargs ):
533538 return cls .__sympy_class__ .__new__ (cls , * args , ** kwargs )
534539
535- def _eval_at (self , func ):
540+ def _eval_at (self , func , ** kwargs ):
536541 return self
537542
538543
@@ -641,6 +646,56 @@ def _gather_for_diff(self):
641646
642647 return self .func (* new_args , evaluate = False )
643648
649+ def _eval_at (self , func , mul_first = False , ** kwargs ):
650+ # Dont evaluate mul first
651+ if not mul_first :
652+ return super ()._eval_at (func , mul_first = mul_first )
653+
654+ # Not a basic a*b*c... expression, just defer to superclass
655+ if any (isinstance (f , DifferentiableOp ) for f in self .args ):
656+ return super ()._eval_at (func , mul_first = mul_first )
657+
658+ # Split Derivative and Differentiable args
659+ derivs , other = split (self .args , lambda e : isinstance (e , sympy .Derivative ))
660+
661+ if derivs :
662+ derivs = Differentiable ._eval_at (self .func (* derivs ), func ,
663+ mul_first = mul_first )
664+ else :
665+ derivs = 1
666+
667+ if not other :
668+ return derivs
669+ elif len (other ) > 1 :
670+ expr = self .func (* other )._gather_for_diff
671+ else :
672+ expr = other [0 ]
673+
674+ # Non differentiable expr (e.g., number)
675+ if not isinstance (expr , Differentiable ):
676+ return self .func (derivs , expr )
677+
678+ # Build mapper for dimensions that need to be interpolated
679+ mapper = {}
680+ for d in self .dimensions :
681+ try :
682+ if self .indices_ref [d ] is not func .indices_ref [d ]:
683+ mapper [d ] = func .indices_ref [d ]
684+ except KeyError :
685+ pass
686+
687+ # Nothing to interpolate
688+ if not mapper :
689+ return super ()._eval_at (func , mul_first = mul_first )
690+
691+ # Interpolate expr at the required indices
692+ interp = expr .diff (* mapper .keys (), deriv_order = [0 for _ in mapper ],
693+ fd_order = [self .interp_order for _ in mapper ],
694+ x0 = mapper )
695+
696+ # Return the full expression with Derivatives
697+ return self .func (derivs , interp )
698+
644699
645700class Pow (DifferentiableOp , sympy .Pow ):
646701 _fd_priority = 0
@@ -987,7 +1042,7 @@ def _subs(self, old, new, **hints):
9871042
9881043class DiffDerivative (IndexDerivative , DifferentiableOp ):
9891044
990- def _eval_at (self , func ):
1045+ def _eval_at (self , func , ** kwargs ):
9911046 # Like EvalDerivative, a DiffDerivative must have already been evaluated
9921047 # at a valid x0 and should not be re-evaluated at a different location
9931048 return self
@@ -1038,7 +1093,7 @@ def _new_rawargs(self, *args, **kwargs):
10381093 kwargs .pop ('is_commutative' , None )
10391094 return self .func (* args , ** kwargs )
10401095
1041- def _eval_at (self , func ):
1096+ def _eval_at (self , func , ** kwargs ):
10421097 # An EvalDerivative must have already been evaluated at a valid x0
10431098 # and should not be re-evaluated at a different location
10441099 return self
0 commit comments