Skip to content

Commit 687750e

Browse files
committed
Allow for custom derivative operators in UnaryGridFunction, and remove restrictions on tensor fields.
1 parent 556dcbd commit 687750e

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

dedalus/core/operators.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -515,17 +515,16 @@ class UnaryGridFunction(NonlinearOperator, FutureField):
515515
Unary function acting on grid data. Must be vectorized
516516
and include an output array argument, e.g. func(x, out).
517517
arg : dedalus operand
518-
Argument field or operator
519-
allow_tensors : bool, optional
520-
Allow application to vectors and tensors (default: False)
518+
Argument field or operator.
519+
deriv : function, optional
520+
Symbolic derivative of func. Defaults are provided
521+
for some common numpy/scipy ufuncs (default: None).
521522
out : field, optional
522-
Output field (default: new field)
523+
Output field (default: new field).
523524
524525
Notes
525526
-----
526-
1. By default, only scalar fields are allowed as arguments. To allow
527-
application to vector and tensor fields, set allow_tensors=True.
528-
2. The supplied function must support an output argument called 'out'
527+
The supplied function must support an output argument called 'out'
529528
and act in a vectorized fashion. The action is essentially:
530529
531530
func(arg['g'], out=out['g'])
@@ -560,12 +559,13 @@ class UnaryGridFunction(NonlinearOperator, FutureField):
560559
aliases.update({ufunc.__name__: ufunc for ufunc in ufunc_derivatives})
561560
aliases.update({'abs': np.absolute, 'conj': np.conjugate})
562561

563-
def __init__(self, func, arg, allow_tensors=False, out=None):
564-
if arg.tensorsig and not allow_tensors:
565-
raise ValueError("ufuncs not defined for vector/tensor fields.")
562+
def __init__(self, func, arg, deriv=None, out=None):
566563
super().__init__(arg, out=out)
567564
self.func = func
568-
self.allow_tensors = allow_tensors
565+
if deriv is None and func in self.ufunc_derivatives:
566+
self.deriv = self.ufunc_derivatives[func]
567+
else:
568+
self.deriv = deriv
569569
# FutureField requirements
570570
self.domain = arg.domain
571571
self.tensorsig = arg.tensorsig
@@ -582,18 +582,18 @@ def _build_bases(self, arg0):
582582
return bases
583583

584584
def new_operand(self, arg):
585-
return UnaryGridFunction(self.func, arg, allow_tensors=self.allow_tensors)
585+
return UnaryGridFunction(self.func, arg, deriv=self.deriv)
586586

587587
def reinitialize(self, **kw):
588588
arg = self.args[0].reinitialize(**kw)
589589
return self.new_operand(arg)
590590

591591
def sym_diff(self, var):
592592
"""Symbolically differentiate with respect to specified operand."""
593-
if self.func not in self.ufunc_derivatives:
593+
if self.deriv is None:
594594
raise ValueError(f"Symbolic derivative not implemented for {self.func.__name__}.")
595595
arg = self.args[0]
596-
return self.ufunc_derivatives[self.func](arg) * arg.sym_diff(var)
596+
return self.deriv(arg) * arg.sym_diff(var)
597597

598598
def check_conditions(self):
599599
# Field must be in grid layout

0 commit comments

Comments
 (0)