@@ -515,17 +515,16 @@ class UnaryGridFunction(NonlinearOperator, FutureField):
515515 Unary function acting on grid data. Must be vectorized
516516 and include an output array argument, e.g. func(x, out).
517517 arg : dedalus operand
518- Argument field or operator
519- allow_tensors : bool, optional
520- Allow application to vectors and tensors (default: False)
518+ Argument field or operator.
519+ deriv : function, optional
520+ Symbolic derivative of func. Defaults are provided
521+ for some common numpy/scipy ufuncs (default: None).
521522 out : field, optional
522- Output field (default: new field)
523+ Output field (default: new field).
523524
524525 Notes
525526 -----
526- 1. By default, only scalar fields are allowed as arguments. To allow
527- application to vector and tensor fields, set allow_tensors=True.
528- 2. The supplied function must support an output argument called 'out'
527+ The supplied function must support an output argument called 'out'
529528 and act in a vectorized fashion. The action is essentially:
530529
531530 func(arg['g'], out=out['g'])
@@ -560,12 +559,13 @@ class UnaryGridFunction(NonlinearOperator, FutureField):
560559 aliases .update ({ufunc .__name__ : ufunc for ufunc in ufunc_derivatives })
561560 aliases .update ({'abs' : np .absolute , 'conj' : np .conjugate })
562561
563- def __init__ (self , func , arg , allow_tensors = False , out = None ):
564- if arg .tensorsig and not allow_tensors :
565- raise ValueError ("ufuncs not defined for vector/tensor fields." )
562+ def __init__ (self , func , arg , deriv = None , out = None ):
566563 super ().__init__ (arg , out = out )
567564 self .func = func
568- self .allow_tensors = allow_tensors
565+ if deriv is None and func in self .ufunc_derivatives :
566+ self .deriv = self .ufunc_derivatives [func ]
567+ else :
568+ self .deriv = deriv
569569 # FutureField requirements
570570 self .domain = arg .domain
571571 self .tensorsig = arg .tensorsig
@@ -582,18 +582,18 @@ def _build_bases(self, arg0):
582582 return bases
583583
584584 def new_operand (self , arg ):
585- return UnaryGridFunction (self .func , arg , allow_tensors = self .allow_tensors )
585+ return UnaryGridFunction (self .func , arg , deriv = self .deriv )
586586
587587 def reinitialize (self , ** kw ):
588588 arg = self .args [0 ].reinitialize (** kw )
589589 return self .new_operand (arg )
590590
591591 def sym_diff (self , var ):
592592 """Symbolically differentiate with respect to specified operand."""
593- if self .func not in self . ufunc_derivatives :
593+ if self .deriv is None :
594594 raise ValueError (f"Symbolic derivative not implemented for { self .func .__name__ } ." )
595595 arg = self .args [0 ]
596- return self .ufunc_derivatives [ self . func ] (arg ) * arg .sym_diff (var )
596+ return self .deriv (arg ) * arg .sym_diff (var )
597597
598598 def check_conditions (self ):
599599 # Field must be in grid layout
0 commit comments