Skip to content

Commit 2dbd9de

Browse files
authored
Merge pull request #590 from cbegeman/add-bsf-function
Port barotropic streamfunction from MPAS-Analysis
2 parents 3e1cab5 + 3b69401 commit 2dbd9de

File tree

6 files changed

+197
-15
lines changed

6 files changed

+197
-15
lines changed

conda_package/docs/api.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,8 @@ Ocean Tools
225225
depth.compute_depth
226226
depth.compute_zmid
227227

228+
compute_barotropic_streamfunction
229+
228230
.. currentmodule:: mpas_tools.ocean.inject_bathymetry
229231

230232
.. autosummary::

conda_package/docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ analyzing simulations, and in other MPAS-related workflows.
5050
ocean/coastline_alteration
5151
ocean/moc
5252
ocean/depth
53+
ocean/streamfunction
5354
ocean/visualization
5455

5556
.. toctree::
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
.. _ocean_streamfunction:
2+
3+
Computing streamfunctions
4+
=========================
5+
6+
Computing the barotropic streamfunction
7+
---------------------------------------
8+
9+
The function :py:func:`mpas_tools.ocean.compute_barotropic_streamfunction()`
10+
computes the barotproic streamfunction at vertices on the MPAS-Ocean grid.
11+
The function takes a dataset containing an MPAS-Ocean mesh and another with
12+
``normalVelocity`` and ``layerThickness`` variables (possibly with a
13+
``timeMonthly_avg_`` prefix). The streamfunction is computed only over the
14+
range of (positive-down) depths provided and at the given time index.
Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
1-
from mpas_tools.ocean.build_mesh import build_spherical_mesh, \
2-
build_planar_mesh
1+
from mpas_tools.ocean.build_mesh import (
2+
build_spherical_mesh,
3+
build_planar_mesh,
4+
)
5+
from mpas_tools.ocean.barotropic_streamfunction import (
6+
compute_barotropic_streamfunction,
7+
)
38
from mpas_tools.ocean.inject_bathymetry import inject_bathymetry
4-
from mpas_tools.ocean.inject_meshDensity import inject_meshDensity_from_file, \
5-
inject_spherical_meshDensity, inject_planar_meshDensity
6-
from mpas_tools.ocean.inject_preserve_floodplain import \
7-
inject_preserve_floodplain
9+
from mpas_tools.ocean.inject_meshDensity import (
10+
inject_meshDensity_from_file,
11+
inject_spherical_meshDensity,
12+
inject_planar_meshDensity,
13+
)
14+
from mpas_tools.ocean.inject_preserve_floodplain import (
15+
inject_preserve_floodplain,
16+
)
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
import xarray as xr
2+
import numpy as np
3+
import scipy.sparse
4+
import scipy.sparse.linalg
5+
6+
import logging
7+
import sys
8+
from mpas_tools.ocean.depth import compute_zmid
9+
10+
11+
def compute_barotropic_streamfunction(ds_mesh, ds, logger=None,
12+
min_depth=-5., max_depth=1.e4,
13+
prefix='timeMonthly_avg_',
14+
time_index=0):
15+
"""
16+
Compute barotropic streamfunction. Returns BSF in Sv on vertices.
17+
18+
Parameters
19+
----------
20+
ds_mesh : ``xarray.Dataset``
21+
A dataset containing MPAS mesh variables
22+
23+
ds : ``xarray.Dataset``
24+
A dataset containing MPAS output variables ``normalVelocity`` and
25+
``layerThickness`` (possibly with a ``prefix``)
26+
27+
logger : ``logging.Logger``, optional
28+
A logger for the output if not stdout
29+
30+
min_depth : float, optional
31+
The minimum depth (positive down) to compute transport over
32+
33+
max_depth : float, optional
34+
The maximum depth (positive down) to compute transport over
35+
36+
prefix : str, optional
37+
The prefix on the ``normalVelocity`` and ``layerThickness`` variables
38+
39+
time_index : int, optional
40+
The time at which to index ``ds`` (if it has ``Time`` as a dimension)
41+
"""
42+
43+
useStdout = logger is None
44+
if useStdout:
45+
logger = logging.getLogger()
46+
logger.addHandler(logging.StreamHandler(sys.stdout))
47+
logger.setLevel(logging.INFO)
48+
49+
inner_edges, transport = _compute_transport(
50+
ds_mesh, ds, min_depth=min_depth, max_depth=max_depth, prefix=prefix,
51+
time_index=time_index)
52+
logger.info('transport computed.')
53+
54+
nvertices = ds_mesh.sizes['nVertices']
55+
56+
cells_on_vertex = ds_mesh.cellsOnVertex - 1
57+
vertices_on_edge = ds_mesh.verticesOnEdge - 1
58+
is_boundary_cov = cells_on_vertex == -1
59+
boundary_vertices = np.logical_or(is_boundary_cov.isel(vertexDegree=0),
60+
is_boundary_cov.isel(vertexDegree=1))
61+
boundary_vertices = np.logical_or(boundary_vertices,
62+
is_boundary_cov.isel(vertexDegree=2))
63+
64+
# convert from boolean mask to indices
65+
boundary_vertices = np.flatnonzero(boundary_vertices.values)
66+
67+
n_boundary_vertices = len(boundary_vertices)
68+
n_inner_edges = len(inner_edges)
69+
70+
indices = np.zeros((2, 2 * n_inner_edges + n_boundary_vertices), dtype=int)
71+
data = np.zeros(2 * n_inner_edges + n_boundary_vertices, dtype=float)
72+
73+
# The difference between the streamfunction at vertices on an inner
74+
# edge should be equal to the transport
75+
v0 = vertices_on_edge.isel(nEdges=inner_edges, TWO=0).values
76+
v1 = vertices_on_edge.isel(nEdges=inner_edges, TWO=1).values
77+
78+
ind = np.arange(n_inner_edges)
79+
indices[0, 2 * ind] = ind
80+
indices[1, 2 * ind] = v1
81+
data[2 * ind] = 1.
82+
83+
indices[0, 2 * ind + 1] = ind
84+
indices[1, 2 * ind + 1] = v0
85+
data[2 * ind + 1] = -1.
86+
87+
# the streamfunction should be zero at all boundary vertices
88+
ind = np.arange(n_boundary_vertices)
89+
indices[0, 2 * n_inner_edges + ind] = n_inner_edges + ind
90+
indices[1, 2 * n_inner_edges + ind] = boundary_vertices
91+
data[2 * n_inner_edges + ind] = 1.
92+
93+
rhs = np.zeros(n_inner_edges + n_boundary_vertices, dtype=float)
94+
95+
# convert to Sv
96+
ind = np.arange(n_inner_edges)
97+
rhs[ind] = 1e-6 * transport
98+
99+
ind = np.arange(n_boundary_vertices)
100+
rhs[n_inner_edges + ind] = 0.
101+
102+
matrix = scipy.sparse.csr_matrix(
103+
(data, indices),
104+
shape=(n_inner_edges + n_boundary_vertices, nvertices))
105+
106+
solution = scipy.sparse.linalg.lsqr(matrix, rhs)
107+
bsf_vertex = xr.DataArray(-solution[0],
108+
dims=('nVertices',))
109+
110+
return bsf_vertex
111+
112+
def _compute_transport(ds_mesh, ds, min_depth, max_depth, prefix,
113+
time_index):
114+
115+
cells_on_edge = ds_mesh.cellsOnEdge - 1
116+
inner_edges = np.logical_and(cells_on_edge.isel(TWO=0) >= 0,
117+
cells_on_edge.isel(TWO=1) >= 0)
118+
119+
if 'Time' in ds.dims:
120+
ds = ds.isel(Time=time_index)
121+
122+
# convert from boolean mask to indices
123+
inner_edges = np.flatnonzero(inner_edges.values)
124+
125+
cell0 = cells_on_edge.isel(nEdges=inner_edges, TWO=0)
126+
cell1 = cells_on_edge.isel(nEdges=inner_edges, TWO=1)
127+
128+
normal_velocity = \
129+
ds[f'{prefix}normalVelocity'].isel(nEdges=inner_edges)
130+
layer_thickness = ds[f'{prefix}layerThickness']
131+
layer_thickness_edge = 0.5 * (layer_thickness.isel(nCells=cell0) +
132+
layer_thickness.isel(nCells=cell1))
133+
134+
n_vert_levels = ds.sizes['nVertLevels']
135+
136+
vert_index = xr.DataArray.from_dict(
137+
{'dims': ('nVertLevels',), 'data': np.arange(n_vert_levels)})
138+
mask_bottom = (vert_index < ds_mesh.maxLevelCell).T
139+
mask_bottom_edge = 0.5 * (mask_bottom.isel(nCells=cell0) +
140+
mask_bottom.isel(nCells=cell1))
141+
142+
if 'zMid' not in ds.keys():
143+
z_mid = compute_zmid(ds_mesh.bottomDepth, ds_mesh.maxLevelCell,
144+
ds_mesh.layerThickness)
145+
else:
146+
z_mid = ds.zMid
147+
z_mid_edge = 0.5 * (z_mid.isel(nCells=cell0) +
148+
z_mid.isel(nCells=cell1))
149+
150+
mask = np.logical_and(np.logical_and(z_mid_edge >= -max_depth,
151+
z_mid_edge <= -min_depth),
152+
mask_bottom_edge)
153+
normal_velocity = normal_velocity.where(mask)
154+
layer_thickness_edge = layer_thickness_edge.where(mask)
155+
transport = ds_mesh.dvEdge[inner_edges] * \
156+
(layer_thickness_edge * normal_velocity).sum(dim='nVertLevels')
157+
158+
return inner_edges, transport

conda_package/mpas_tools/ocean/depth.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def compute_depth(refBottomDepth):
3737
depth_bnds[0, 0] = 0.
3838
depth_bnds[1:, 0] = refBottomDepth[0:-1]
3939
depth_bnds[:, 1] = refBottomDepth
40-
depth = 0.5*(depth_bnds[:, 0] + depth_bnds[:, 1])
40+
depth = 0.5 * (depth_bnds[:, 0] + depth_bnds[:, 1])
4141

4242
return depth, depth_bnds
4343

@@ -82,11 +82,11 @@ def compute_zmid(bottomDepth, maxLevelCell, layerThickness,
8282

8383
thicknessSum = layerThickness.sum(dim=depth_dim)
8484
thicknessCumSum = layerThickness.cumsum(dim=depth_dim)
85-
zSurface = -bottomDepth+thicknessSum
85+
zSurface = -bottomDepth + thicknessSum
8686

8787
zLayerBot = zSurface - thicknessCumSum
8888

89-
zMid = zLayerBot + 0.5*layerThickness
89+
zMid = zLayerBot + 0.5 * layerThickness
9090

9191
zMid = zMid.where(vertIndex < maxLevelCell)
9292
if 'Time' in zMid.dims:
@@ -150,8 +150,7 @@ def add_depth(inFileName, outFileName, coordFileName=None):
150150
history = '{}: {}'.format(time, ' '.join(sys.argv))
151151

152152
if 'history' in ds.attrs:
153-
ds.attrs['history'] = '{}\n{}'.format(history,
154-
ds.attrs['history'])
153+
ds.attrs['history'] = f'{history}\n{ds.attrs["history"]}'
155154
else:
156155
ds.attrs['history'] = history
157156

@@ -229,8 +228,7 @@ def add_zmid(inFileName, outFileName, coordFileName=None):
229228
history = '{}: {}'.format(time, ' '.join(sys.argv))
230229

231230
if 'history' in ds.attrs:
232-
ds.attrs['history'] = '{}\n{}'.format(history,
233-
ds.attrs['history'])
231+
ds.attrs['history'] = f'{history}\n{ds.attrs["history"]}'
234232
else:
235233
ds.attrs['history'] = history
236234

@@ -288,8 +286,8 @@ def write_time_varying_zmid(inFileName, outFileName, coordFileName=None,
288286

289287
dsIn = xarray.open_dataset(inFileName)
290288
dsIn = dsIn.rename({'nVertLevels': 'depth'})
291-
inVarName = '{}layerThickness'.format(prefix)
292-
outVarName = '{}zMid'.format(prefix)
289+
inVarName = f'{prefix}layerThickness'
290+
outVarName = f'{prefix}zMid'
293291
layerThickness = dsIn[inVarName]
294292

295293
zMid = compute_zmid(dsCoord.bottomDepth, dsCoord.maxLevelCell,

0 commit comments

Comments
 (0)