Skip to content

Commit f264159

Browse files
committed
Add normalization method to fields. Remove all old references to SphericalShell and use Shell instead.
1 parent 8a10b67 commit f264159

13 files changed

+76
-35
lines changed

dedalus/core/basis.py

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ def __init__(self, coord, size, bounds, dealias):
313313
self.size = size
314314
self.shape = (size,)
315315
self.bounds = bounds
316+
self.volume = bounds[1] - bounds[0]
316317
if isinstance(dealias, tuple):
317318
self.dealias = dealias
318319
else:
@@ -1960,6 +1961,7 @@ def __init__(self, coordsystem, shape, dtype, radii=(1,2), k=0, alpha=(-0.5,-0.5
19601961
radius_library = "matrix"
19611962
self.radius_library = radius_library
19621963
self.radii = tuple(radii)
1964+
self.volume = np.pi * (radii[1]**2 - radii[0]**2)
19631965
self.dR = radii[1] - radii[0]
19641966
self.rho = (radii[1] + radii[0])/self.dR
19651967
self.alpha = tuple(alpha)
@@ -2207,6 +2209,7 @@ def __init__(self, coordsystem, shape, dtype, radius=1, k=0, alpha=0, dealias=(1
22072209
radius_library = "matrix"
22082210
self.radius_library = radius_library
22092211
self.radius = radius
2212+
self.volume = np.pi * radius**2
22102213
self.alpha = alpha
22112214
self.radial_COV = AffineCOV((0, 1), (0, radius))
22122215
if self.mmax > 2*self.Nmax:
@@ -2538,6 +2541,7 @@ def __init__(self, coordsystem, shape, dtype, radius=1, dealias=(1,1), colatitud
25382541
if colatitude_library is None:
25392542
colatitude_library = "matrix"
25402543
self.radius = radius
2544+
self.volume = 4 * np.pi * radius**2
25412545
self.colatitude_library = colatitude_library
25422546
# Set Lmax for optimal load balancing
25432547
if self.dtype == np.float64:
@@ -3153,7 +3157,7 @@ def symbol(spinindex_in, spinindex_out, ell, radius):
31533157
return k_lap / radius**2
31543158

31553159

3156-
# These are common for BallRadialBasis and SphericalShellRadialBasis
3160+
# These are common for BallRadialBasis and ShellRadialBasis
31573161
class RegularityBasis(SpinRecombinationBasis, MultidimensionalBasis):
31583162

31593163
dim = 3
@@ -3414,7 +3418,7 @@ def n_slice(self, ell):
34143418
return slice(nmin, nmax+1)
34153419

34163420

3417-
class SphericalShellRadialBasis(RegularityBasis, metaclass=CachedClass):
3421+
class ShellRadialBasis(RegularityBasis, metaclass=CachedClass):
34183422

34193423
def __init__(self, coordsystem, radial_size, dtype, radii=(1,2), alpha=(-0.5,-0.5), dealias=(1,), k=0, radius_library=None):
34203424
super().__init__(coordsystem, radial_size, k=k, dealias=dealias, dtype=dtype)
@@ -3426,6 +3430,7 @@ def __init__(self, coordsystem, radial_size, dtype, radii=(1,2), alpha=(-0.5,-0.
34263430
else:
34273431
radius_library = "matrix"
34283432
self.radii = radii
3433+
self.volume = 4 / 3 * np.pi * (radii[1]**3 - radii[0]**3)
34293434
self.dR = self.radii[1] - self.radii[0]
34303435
self.rho = (self.radii[1] + self.radii[0])/self.dR
34313436
self.alpha = alpha
@@ -3439,7 +3444,7 @@ def __init__(self, coordsystem, radial_size, dtype, radii=(1,2), alpha=(-0.5,-0.
34393444
self.backward_transform_radius]
34403445

34413446
def __eq__(self, other):
3442-
if isinstance(other, SphericalShellRadialBasis):
3447+
if isinstance(other, ShellRadialBasis):
34433448
if self.coordsystem == other.coordsystem:
34443449
if self.grid_params == other.grid_params:
34453450
if self.k == other.k:
@@ -3454,7 +3459,7 @@ def __add__(self, other):
34543459
return self
34553460
if other is self:
34563461
return self
3457-
if isinstance(other, SphericalShellRadialBasis):
3462+
if isinstance(other, ShellRadialBasis):
34583463
if self.grid_params == other.grid_params:
34593464
radial_size = max(self.shape[2], other.shape[2])
34603465
k = max(self.k, other.k)
@@ -3464,7 +3469,7 @@ def __add__(self, other):
34643469
def __mul__(self, other):
34653470
if other is None:
34663471
return self
3467-
if isinstance(other, SphericalShellRadialBasis):
3472+
if isinstance(other, ShellRadialBasis):
34683473
if self.grid_params == other.grid_params:
34693474
radial_size = max(self.shape[2], other.shape[2])
34703475
k = self.k + other.k
@@ -3482,7 +3487,7 @@ def __mul__(self, other):
34823487
args['azimuth_library'] = other.azimuth_library
34833488
args['colatitude_library'] = other.colatitude_library
34843489
args['radius_library'] = self.radius_library
3485-
return SphericalShellBasis(**args)
3490+
return ShellBasis(**args)
34863491
return NotImplemented
34873492

34883493
def __matmul__(self, other):
@@ -3491,7 +3496,7 @@ def __matmul__(self, other):
34913496
def __rmatmul__(self, other):
34923497
if other is None:
34933498
return self
3494-
if isinstance(other, SphericalShellRadialBasis):
3499+
if isinstance(other, ShellRadialBasis):
34953500
if self.grid_params == other.grid_params:
34963501
radial_size = max(self.shape[2], other.shape[2])
34973502
k = self.k + other.k
@@ -3650,6 +3655,7 @@ def __init__(self, coordsystem, radial_size, dtype, radius=1, k=0, alpha=0, deal
36503655
if radius_library is None:
36513656
radius_library = "matrix"
36523657
self.radius = radius
3658+
self.volume = 4 / 3 * np.pi * radius**3
36533659
self.alpha = alpha
36543660
self.radial_COV = AffineCOV((0, 1), (0, radius))
36553661
self.radius_library = radius_library
@@ -4052,15 +4058,16 @@ def multiplication_matrix(self, subproblem, arg_basis, coeffs, *args, **kw):
40524058
raise ValueError("Cannot build NCCs of non-radial fields.")
40534059

40544060

4055-
class SphericalShellBasis(Spherical3DBasis, metaclass=CachedClass):
4061+
class ShellBasis(Spherical3DBasis, metaclass=CachedClass):
40564062

40574063
def __init__(self, coordsystem, shape, dtype, radii=(1,2), alpha=(-0.5,-0.5), dealias=(1,1,1), k=0, azimuth_library=None, colatitude_library=None, radius_library=None):
40584064
if np.isscalar(dealias):
40594065
dealias = (dealias, dealias, dealias)
40604066
self.alpha = alpha
40614067
self.radii = radii
4068+
self.volume = 4 / 3 * np.pi * (radii[1]**3 - radii[0]**3)
40624069
self.radius_library = radius_library
4063-
self.radial_basis = SphericalShellRadialBasis(coordsystem, shape[2], radii=radii, alpha=alpha, dealias=(dealias[2],), k=k, dtype=dtype, radius_library=radius_library)
4070+
self.radial_basis = ShellRadialBasis(coordsystem, shape[2], radii=radii, alpha=alpha, dealias=(dealias[2],), k=k, dtype=dtype, radius_library=radius_library)
40644071
Spherical3DBasis.__init__(self, coordsystem, shape[:2], dealias[:2], self.radial_basis, dtype=dtype, azimuth_library=azimuth_library, colatitude_library=colatitude_library)
40654072
self.grid_params = (coordsystem, radii, alpha, dealias)
40664073
# self.forward_transform_radius = self.radial_basis.forward_transform
@@ -4073,7 +4080,7 @@ def __init__(self, coordsystem, shape, dtype, radii=(1,2), alpha=(-0.5,-0.5), de
40734080
self.backward_transform_radius]
40744081

40754082
def __eq__(self, other):
4076-
if isinstance(other, SphericalShellBasis):
4083+
if isinstance(other, ShellBasis):
40774084
if self.coordsystem == other.coordsystem:
40784085
if self.grid_params == other.grid_params:
40794086
if self.k == other.k:
@@ -4088,40 +4095,40 @@ def __add__(self, other):
40884095
return self
40894096
if other is self:
40904097
return self
4091-
if isinstance(other, SphericalShellBasis):
4098+
if isinstance(other, ShellBasis):
40924099
if self.grid_params == other.grid_params:
40934100
shape = tuple(np.maximum(self.shape, other.shape))
40944101
k = max(self.k, other.k)
4095-
return SphericalShellBasis(self.coordsystem, shape, radii=self.radial_basis.radii, alpha=self.radial_basis.alpha, dealias=self.dealias, k=k,
4102+
return ShellBasis(self.coordsystem, shape, radii=self.radial_basis.radii, alpha=self.radial_basis.alpha, dealias=self.dealias, k=k,
40964103
dtype=self.dtype, azimuth_library=self.azimuth_library, colatitude_library=self.colatitude_library,
40974104
radius_library=self.radial_basis.radius_library)
40984105
return NotImplemented
40994106

41004107
def __mul__(self, other):
41014108
if other is None:
41024109
return self
4103-
if isinstance(other, SphericalShellBasis):
4110+
if isinstance(other, ShellBasis):
41044111
if self.grid_params == other.grid_params:
41054112
shape = tuple(np.maximum(self.shape, other.shape))
41064113
k = 0
4107-
return SphericalShellBasis(self.coordsystem, shape, radii=self.radial_basis.radii, alpha=self.radial_basis.alpha, dealias=self.dealias, k=k,
4114+
return ShellBasis(self.coordsystem, shape, radii=self.radial_basis.radii, alpha=self.radial_basis.alpha, dealias=self.dealias, k=k,
41084115
dtype=self.dtype, azimuth_library=self.azimuth_library, colatitude_library=self.colatitude_library,
41094116
radius_library=self.radial_basis.radius_library)
4110-
if isinstance(other, SphericalShellRadialBasis):
4117+
if isinstance(other, ShellRadialBasis):
41114118
radial_basis = other * self.radial_basis
41124119
return self._new_k(radial_basis.k)
41134120
return NotImplemented
41144121

41154122
def __rmatmul__(self, other):
41164123
if other is None:
41174124
return self
4118-
if isinstance(other, SphericalShellRadialBasis):
4125+
if isinstance(other, ShellRadialBasis):
41194126
radial_basis = other @ self.radial_basis
41204127
return self._new_k(radial_basis.k)
41214128
return NotImplemented
41224129

41234130
def _new_k(self, k):
4124-
return SphericalShellBasis(self.coordsystem, self.shape, radii=self.radial_basis.radii, alpha=self.radial_basis.alpha, dealias=self.dealias, k=k,
4131+
return ShellBasis(self.coordsystem, self.shape, radii=self.radial_basis.radii, alpha=self.radial_basis.alpha, dealias=self.dealias, k=k,
41254132
dtype=self.dtype, azimuth_library=self.azimuth_library, colatitude_library=self.colatitude_library,
41264133
radius_library=self.radial_basis.radius_library)
41274134

@@ -4162,9 +4169,6 @@ def backward_transform_radius(self, field, axis, cdata, gdata):
41624169
gdata *= radial_basis.radial_transform_factor(field.scales[axis], data_axis, self.k)
41634170

41644171

4165-
ShellBasis = SphericalShellBasis
4166-
4167-
41684172
class BallBasis(Spherical3DBasis, metaclass=CachedClass):
41694173

41704174
transforms = {}
@@ -4174,6 +4178,7 @@ def __init__(self, coordsystem, shape, dtype, radius=1, k=0, alpha=0, dealias=(1
41744178
dealias = (dealias, dealias, dealias)
41754179
self.alpha = alpha
41764180
self.radius = radius
4181+
self.volume = 4 / 3 * np.pi * radius**3
41774182
self.radius_library = radius_library
41784183
self.radial_basis = BallRadialBasis(coordsystem, shape[2], radius=radius, k=k, alpha=alpha, dealias=(dealias[2],), dtype=dtype, radius_library=radius_library)
41794184
Spherical3DBasis.__init__(self, coordsystem, shape[:2], dealias[:2], self.radial_basis, dtype=dtype, azimuth_library=azimuth_library, colatitude_library=colatitude_library)
@@ -4347,7 +4352,7 @@ def radial_matrix(self, regindex_in, regindex_out, ell):
43474352

43484353
class ConvertConstantShell(operators.ConvertConstant, operators.SphericalEllOperator):
43494354

4350-
output_basis_type = (ShellBasis, SphericalShellRadialBasis)
4355+
output_basis_type = (ShellBasis, ShellRadialBasis)
43514356
subaxis_dependence = [False, True, True]
43524357
subaxis_coupling = [False, False, True]
43534358

@@ -4657,7 +4662,7 @@ def radial_matrix(self, regindex_in, regindex_out, m):
46574662
class LiftShell(operators.Lift, operators.SphericalEllOperator):
46584663

46594664
input_basis_type = SphereBasis
4660-
output_basis_type = SphericalShellBasis
4665+
output_basis_type = ShellBasis
46614666

46624667
def regindex_out(self, regindex_in):
46634668
return (regindex_in,)
@@ -5252,9 +5257,9 @@ def _radial_matrix(basis, ell, regtotal, position):
52525257
return reshape_vector(basis.interpolation(ell, regtotal, position), dim=2, axis=1)
52535258

52545259

5255-
class SphericalShellRadialInterpolate(operators.Interpolate, operators.SphericalEllOperator):
5260+
class ShellRadialInterpolate(operators.Interpolate, operators.SphericalEllOperator):
52565261

5257-
basis_type = SphericalShellBasis
5262+
basis_type = ShellBasis
52585263
basis_subaxis = 2
52595264

52605265
def __init__(self, operand, coord, position, out=None):

dedalus/core/domain.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
from collections import OrderedDict
88

9+
from ..tools.array import prod
910
from ..tools.cache import CachedMethod, CachedClass, CachedAttribute
1011
from ..tools.general import unify_attributes, unify, OrderedSet
1112
from .coords import Coordinate, CartesianCoordinates
@@ -45,6 +46,10 @@ def __init__(self, dist, bases):
4546
self.bases = bases # Preprocessed to remove Nones and duplicates
4647
self.dim = sum(basis.dim for basis in self.bases)
4748

49+
@CachedAttribute
50+
def volume(self):
51+
return prod([basis.volume for basis in self.bases])
52+
4853
@CachedAttribute
4954
def bases_by_axis(self):
5055
bases_by_axis = OrderedDict()

dedalus/core/field.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,37 @@ def allreduce_data_norm(self, layout=None, order=2):
731731
def allreduce_data_max(self, layout=None):
732732
return self.allreduce_data_norm(layout=layout, order=np.inf)
733733

734+
def allreduce_L2_norm(self, normalize_volume=True):
735+
from . import arithmetic
736+
from . import operators
737+
# Compute local self inner product
738+
rank = len(self.tensorsig)
739+
if rank == 0:
740+
self_inner_product = np.conj(self) * self
741+
elif rank == 1:
742+
self_inner_product = arithmetic.dot(np.conj(self), self)
743+
elif rank == 2:
744+
self_inner_product = arithmetic.Trace(arithmetic.Dot(operators.Transpose(np.conj(self)), self))
745+
else:
746+
raise ValueError("Norms only implemented up to rank-2 tensors.")
747+
# Compute L2 norm
748+
norm_sq = operators.Integrate(self_inner_product).evaluate().allreduce_data_max()
749+
if normalize_volume:
750+
norm_sq /= self.domain.volume
751+
return norm_sq ** 0.5
752+
753+
def normalize(self, normalize_volume=True):
754+
"""
755+
Normalize field inplace using L2 norm.
756+
757+
Parameters
758+
----------
759+
normalize_volume : bool, optional
760+
Normalize inner product by domain volume. Default: True.
761+
"""
762+
norm = self.allreduce_L2_norm(normalize_volume=normalize_volume)
763+
self.data /= norm
764+
734765
def broadcast_ghosts(self, output_nonconst_dims):
735766
"""Copy data over constant distributed dimensions for arithmetic broadcasting."""
736767
# Determine deployment dimensions

dedalus/tests/test_grid_operators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def build_ball(Nphi, Ntheta, Nr, dtype, dealias, radius=1):
4343
def build_shell(Nphi, Ntheta, Nr, dtype, dealias, radii=(0.5,1)):
4444
c = coords.SphericalCoordinates('phi', 'theta', 'r')
4545
d = distributor.Distributor((c,))
46-
b = basis.SphericalShellBasis(c, (Nphi, Ntheta, Nr), radii=radii, dtype=dtype, dealias=(dealias, dealias, dealias))
46+
b = basis.ShellBasis(c, (Nphi, Ntheta, Nr), radii=radii, dtype=dtype, dealias=(dealias, dealias, dealias))
4747
phi, theta, r = b.local_grids()
4848
x, y, z = c.cartesian(phi, theta, r)
4949
return c, d, b, phi, theta, r, x, y, z

dedalus/tests/test_ivp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def build_shell(Nphi, Ntheta, Nr, radii_shell, dealias, dtype, grid_scale=1):
2525
c = coords.SphericalCoordinates('phi', 'theta', 'r')
2626
d = distributor.Distributor((c,))
2727
dealias_tuple = (dealias, dealias, dealias)
28-
b = basis.SphericalShellBasis(c, (Nphi, Ntheta, Nr), radii=radii_shell, dealias=dealias_tuple, dtype=dtype)
28+
b = basis.ShellBasis(c, (Nphi, Ntheta, Nr), radii=radii_shell, dealias=dealias_tuple, dtype=dtype)
2929
grid_scale_tuple = (grid_scale, grid_scale, grid_scale)
3030
phi, theta, r = b.local_grids(grid_scale_tuple)
3131
x, y, z = c.cartesian(phi, theta, r)

dedalus/tests/test_lbvp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ def test_heat_ball_cart(Nmax, Lmax, dtype):
283283
def build_shell(Nphi, Ntheta, Nr, dealias, dtype):
284284
c = coords.SphericalCoordinates('phi', 'theta', 'r')
285285
d = distributor.Distributor((c,))
286-
b = basis.SphericalShellBasis(c, (Nphi, Ntheta, Nr), radii=radii_shell, dealias=(dealias, dealias, dealias), dtype=dtype)
286+
b = basis.ShellBasis(c, (Nphi, Ntheta, Nr), radii=radii_shell, dealias=(dealias, dealias, dealias), dtype=dtype)
287287
phi, theta, r = b.local_grids()
288288
x, y, z = c.cartesian(phi, theta, r)
289289
return c, d, b, phi, theta, r, x, y, z

dedalus/tests/test_output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def build_shell(Nphi, Ntheta, Nr, k, dealias, dtype):
7171
radii_shell = (0.5, 1.5)
7272
c = coords.SphericalCoordinates('phi', 'theta', 'r')
7373
d = distributor.Distributor((c,))
74-
b = basis.SphericalShellBasis(c, (Nphi, Ntheta, Nr), radii=radii_shell, k=k, dealias=(dealias, dealias, dealias), dtype=dtype)
74+
b = basis.ShellBasis(c, (Nphi, Ntheta, Nr), radii=radii_shell, k=k, dealias=(dealias, dealias, dealias), dtype=dtype)
7575
phi, theta, r = b.local_grids(b.domain.dealias)
7676
x, y, z = c.cartesian(phi, theta, r)
7777
return c, d, b, phi, theta, r, x, y, z

dedalus/tests/test_spherical3D_arithmetic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def build_ball(Nphi, Ntheta, Nr, dealias, dtype):
2828
def build_shell(Nphi, Ntheta, Nr, dealias, dtype):
2929
c = coords.SphericalCoordinates('phi', 'theta', 'r')
3030
d = distributor.Distributor((c,))
31-
b = basis.SphericalShellBasis(c, (Nphi, Ntheta, Nr), radii=radii_shell, dealias=(dealias, dealias, dealias), dtype=dtype)
31+
b = basis.ShellBasis(c, (Nphi, Ntheta, Nr), radii=radii_shell, dealias=(dealias, dealias, dealias), dtype=dtype)
3232
phi, theta, r = b.local_grids()
3333
x, y, z = c.cartesian(phi, theta, r)
3434
return c, d, b, phi, theta, r, x, y, z

dedalus/tests/test_spherical3D_calculus.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def build_ball(Nphi, Ntheta, Nr, dealias, dtype):
2828
def build_shell(Nphi, Ntheta, Nr, dealias, dtype):
2929
c = coords.SphericalCoordinates('phi', 'theta', 'r')
3030
d = distributor.Distributor((c,))
31-
b = basis.SphericalShellBasis(c, (Nphi, Ntheta, Nr), radii=radii_shell, dealias=(dealias, dealias, dealias), dtype=dtype)
31+
b = basis.ShellBasis(c, (Nphi, Ntheta, Nr), radii=radii_shell, dealias=(dealias, dealias, dealias), dtype=dtype)
3232
phi, theta, r = b.local_grids(b.domain.dealias)
3333
x, y, z = c.cartesian(phi, theta, r)
3434
return c, d, b, phi, theta, r, x, y, z

dedalus/tests/test_spherical3D_operators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def build_ball(Nphi, Ntheta, Nr, k, dealias, dtype):
3030
def build_shell(Nphi, Ntheta, Nr, k, dealias, dtype):
3131
c = coords.SphericalCoordinates('phi', 'theta', 'r')
3232
d = distributor.Distributor((c,))
33-
b = basis.SphericalShellBasis(c, (Nphi, Ntheta, Nr), radii=radii_shell, k=k, dealias=(dealias, dealias, dealias), dtype=dtype)
33+
b = basis.ShellBasis(c, (Nphi, Ntheta, Nr), radii=radii_shell, k=k, dealias=(dealias, dealias, dealias), dtype=dtype)
3434
phi, theta, r = b.local_grids(b.domain.dealias)
3535
x, y, z = c.cartesian(phi, theta, r)
3636
return c, d, b, phi, theta, r, x, y, z

0 commit comments

Comments
 (0)