Skip to content

Commit 8b709a9

Browse files
committed
Add some helpers: basis.derivative_basis() method, field.T and field.H for regular/Hermitian transpose, and dist.IdentityTensor.
1 parent c153f2e commit 8b709a9

File tree

11 files changed

+45
-26
lines changed

11 files changed

+45
-26
lines changed

dedalus/core/basis.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,11 @@ def multiplication_matrix(self, subproblem, arg_basis, coeffs, ncc_comp, arg_com
556556
matrix = convert @ matrix
557557
return matrix[:N, :N]
558558

559+
def derivative_basis(self, order=1):
560+
a = self.a + order
561+
b = self.b + order
562+
return self.clone_with(a=a, b=b)
563+
559564

560565
def Legendre(*args, **kw):
561566
return Jacobi(*args, a=0, b=0, **kw)
@@ -631,9 +636,7 @@ class DifferentiateJacobi(operators.Differentiate, operators.SpectralOperator1D)
631636

632637
@staticmethod
633638
def _output_basis(input_basis):
634-
a = input_basis.a + 1
635-
b = input_basis.b + 1
636-
return input_basis.clone_with(a=a, b=b)
639+
return input_basis.derivative_basis(order=1)
637640

638641
@staticmethod
639642
@CachedMethod
@@ -1944,6 +1947,10 @@ def radius_multiplication_matrix(self, m, spintotal, order, d):
19441947
operator = R2**(d//2) @ operator
19451948
return operator(self.n_size(m), self.alpha + self.k, abs(m + spintotal)).square.astype(np.float64)
19461949

1950+
def derivative_basis(self, order=1):
1951+
k = self.k + order
1952+
return self.clone_with(k=k)
1953+
19471954

19481955
class AnnulusBasis(PolarBasis):
19491956

@@ -3924,6 +3931,10 @@ def matrix_dependence(self, matrix_coupling):
39243931
def constant(self):
39253932
return (self.Lmax==0, self.Lmax==0, False)
39263933

3934+
def derivative_basis(self, order=1):
3935+
k = self.k + order
3936+
return self.clone_with(k=k)
3937+
39273938
@CachedAttribute
39283939
def constant_mode_value(self):
39293940
# Adjust for SWSH normalization

dedalus/core/distributor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,14 @@ def TensorField(self, *args, **kw):
216216
from .field import TensorField
217217
return TensorField(self, *args, **kw)
218218

219+
def IdentityTensor(self, coordsys):
220+
"""Identity tensor field."""
221+
from .field import TensorField
222+
I = TensorField(self, (coordsys, coordsys))
223+
for i in range(coordsys.dim):
224+
I['g'][i, i] = 1
225+
return I
226+
219227
def local_grid(self, basis, scale=None):
220228
# TODO: remove from bases and do it all here?
221229
if basis.dim == 1:

dedalus/core/field.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,15 @@ def expression_matrices(self, subproblem, vars, **kw):
247247
"""Build expression matrices for a specific subproblem and variables."""
248248
raise NotImplementedError()
249249

250+
@property
251+
def T(self):
252+
from .operators import TransposeComponents
253+
return TransposeComponents(self)
250254

255+
@property
256+
def H(self):
257+
from .operators import TransposeComponents
258+
return TransposeComponents(np.conj(self))
251259

252260

253261
class Current(Operand):

dedalus/core/operators.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2813,8 +2813,7 @@ def __init__(self, operand, coordsys, out=None):
28132813

28142814
@staticmethod
28152815
def _output_basis(input_basis):
2816-
out = input_basis._new_k(input_basis.k + 1)
2817-
return out
2816+
return input_basis.derivative_basis(1)
28182817

28192818
def check_conditions(self):
28202819
"""Check that operands are in a proper layout."""
@@ -2962,8 +2961,7 @@ def __init__(self, operand, coordsys, out=None):
29622961

29632962
@staticmethod
29642963
def _output_basis(input_basis):
2965-
out = input_basis._new_k(input_basis.k + 1)
2966-
return out
2964+
return input_basis.derivative_basis(1)
29672965

29682966
def check_conditions(self):
29692967
"""Check that operands are in a proper layout."""
@@ -3164,8 +3162,7 @@ def __init__(self, operand, index=0, out=None):
31643162

31653163
@staticmethod
31663164
def _output_basis(input_basis):
3167-
out = input_basis._new_k(input_basis.k + 1)
3168-
return out
3165+
return input_basis.derivative_basis(1)
31693166

31703167
def check_conditions(self):
31713168
"""Check that operands are in a proper layout."""
@@ -3225,8 +3222,7 @@ def __init__(self, operand, index=0, out=None):
32253222

32263223
@staticmethod
32273224
def _output_basis(input_basis):
3228-
out = input_basis._new_k(input_basis.k + 1)
3229-
return out
3225+
return input_basis.derivative_basis(1)
32303226

32313227
def check_conditions(self):
32323228
"""Check that operands are in a proper layout."""
@@ -3371,8 +3367,7 @@ def __init__(self, operand, index=0, out=None):
33713367

33723368
@staticmethod
33733369
def _output_basis(input_basis):
3374-
out = input_basis._new_k(input_basis.k + 1)
3375-
return out
3370+
return input_basis.derivative_basis(1)
33763371

33773372
def check_conditions(self):
33783373
"""Check that operands are in a proper layout."""
@@ -3515,8 +3510,7 @@ def __init__(self, operand, index=0, out=None):
35153510

35163511
@staticmethod
35173512
def _output_basis(input_basis):
3518-
out = input_basis._new_k(input_basis.k + 1)
3519-
return out
3513+
return input_basis.derivative_basis(1)
35203514

35213515
def check_conditions(self):
35223516
"""Check that operands are in a proper layout."""
@@ -3718,8 +3712,7 @@ def __init__(self, operand, coordsys, out=None):
37183712

37193713
@staticmethod
37203714
def _output_basis(input_basis):
3721-
out = input_basis._new_k(input_basis.k + 2)
3722-
return out
3715+
return input_basis.derivative_basis(2)
37233716

37243717
def check_conditions(self):
37253718
"""Check that operands are in a proper layout."""
@@ -3827,8 +3820,7 @@ def __init__(self, operand, coordsys, out=None):
38273820

38283821
@staticmethod
38293822
def _output_basis(input_basis):
3830-
out = input_basis._new_k(input_basis.k + 2)
3831-
return out
3823+
return input_basis.derivative_basis(2)
38323824

38333825
def check_conditions(self):
38343826
"""Check that operands are in a proper layout."""

docs/notebooks/dedalus_tutorial_3.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@
179179
"c = -1.76\n",
180180
"\n",
181181
"# Tau polynomials\n",
182-
"tau_basis = xbasis.clone_with(a=1.5, b=1.5)\n",
182+
"tau_basis = xbasis.derivative_basis(2)\n",
183183
"p1 = dist.Field(bases=tau_basis)\n",
184184
"p2 = dist.Field(bases=tau_basis)\n",
185185
"p1['c'][-1] = 1\n",

docs/notebooks/dedalus_tutorial_4.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@
106106
"c = -1.76\n",
107107
"\n",
108108
"# Tau polynomials\n",
109-
"tau_basis = xbasis.clone_with(a=1.5, b=1.5)\n",
109+
"tau_basis = xbasis.derivative_basis(2)\n",
110110
"p1 = dist.Field(bases=tau_basis)\n",
111111
"p2 = dist.Field(bases=tau_basis)\n",
112112
"p1['c'][-1] = 1\n",

docs/pages/tau_method.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ Here we'll take :math:`P(y)` to be the highest mode in the Chebyshev-U basis, in
132132
133133
# Substitutions
134134
ex, ey = coords.unit_vector_fields(dist)
135-
lift_basis = ybasis.clone_with(a=1/2, b=1/2) # Chebyshev U basis
135+
lift_basis = ybasis.derivative_basis(1) # Chebyshev U basis
136136
lift = lambda A, n: d3.Lift(A, lift_basis, -1) # Shortcut for multiplying by U_{N-1}(y)
137137
grad_u = d3.grad(u) - ey*lift(tau_u1) # Operator representing G
138138

examples/evp_1d_waves_on_a_string/waves_on_a_string.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242
# Substitutions
4343
dx = lambda A: d3.Differentiate(A, xcoord)
44-
lift_basis = xbasis.clone_with(a=1/2, b=1/2) # First derivative basis
44+
lift_basis = xbasis.derivative_basis(1)
4545
lift = lambda A: d3.Lift(A, lift_basis, -1)
4646
ux = dx(u) + lift(tau_1) # First-order reduction
4747

examples/ivp_2d_rayleigh_benard/rayleigh_benard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
nu = (Rayleigh / Prandtl)**(-1/2)
6262
x, z = dist.local_grids(xbasis, zbasis)
6363
ex, ez = coords.unit_vector_fields(dist)
64-
lift_basis = zbasis.clone_with(a=1/2, b=1/2) # First derivative basis
64+
lift_basis = zbasis.derivative_basis(1)
6565
lift = lambda A: d3.Lift(A, lift_basis, -1)
6666
grad_u = d3.grad(u) + ez*lift(tau_u1) # First-order reduction
6767
grad_b = d3.grad(b) + ez*lift(tau_b1) # First-order reduction

examples/ivp_shell_convection/shell_convection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
er['g'][2] = 1
6666
rvec = dist.VectorField(coords, bases=basis.radial_basis)
6767
rvec['g'][2] = r
68-
lift_basis = basis.clone_with(k=1) # First derivative basis
68+
lift_basis = basis.derivative_basis(1)
6969
lift = lambda A: d3.Lift(A, lift_basis, -1)
7070
grad_u = d3.grad(u) + rvec*lift(tau_u1) # First-order reduction
7171
grad_b = d3.grad(b) + rvec*lift(tau_b1) # First-order reduction

0 commit comments

Comments
 (0)