Skip to content

Commit e04edff

Browse files
committed
small cleanup
1 parent 9b0bfdd commit e04edff

File tree

3 files changed

+89
-91
lines changed

3 files changed

+89
-91
lines changed

xmitgcm/llcreader/known_models.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -267,17 +267,3 @@ def __init__(self):
267267
shrunk=True, join_char='/')
268268

269269
super(SverdrupASTE270Model, self).__init__(store)
270-
271-
#class SverdrupSmoothLLC90Model(SmoothLLC90Model):
272-
# @_requires_sverdrup
273-
# def __init__(self):
274-
# from fsspec.implementations.local import LocalFileSystem
275-
# fs = LocalFileSystem()
276-
# base_path = '/scratch2/shared/aste-release1/diags'
277-
# grid_path = '/scratch2/shared/aste-release1/grid'
278-
# mask_path = '/scratch2/shared/aste-release1/masks.zarr'
279-
# store = stores.NestedStore(fs, base_path=base_path, grid_path=grid_path,
280-
# mask_path=mask_path,
281-
# shrunk=True, join_char='/')
282-
#
283-
# super(SverdrupASTE270Model, self).__init__(store)

xmitgcm/llcreader/llcmodel.py

Lines changed: 81 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def _get_grid_metadata():
3737

3838
return grid_metadata
3939

40-
def _get_var_metadata():
40+
def _get_var_metadata(extra_variables=None):
4141
# The LLC run data comes with zero metadata. So we import metadata from
4242
# the xmitgcm package.
4343
from ..variables import state_variables, package_state_variables
@@ -50,16 +50,8 @@ def _get_var_metadata():
5050
var_metadata = state_variables.copy()
5151
var_metadata.update(package_state_variables)
5252
var_metadata.update(available_diags)
53-
extra_variables = {
54-
f'smooth3Dfld001': {
55-
'dims': ['k', 'j', 'i'],
56-
'attrs': {
57-
'standard_name': 'smooth_fld',
58-
'long_name': r'$C\mathbf{z}$',
59-
}
60-
},
61-
}
62-
var_metadata.update(extra_variables)
53+
if extra_variables is not None:
54+
var_metadata.update(extra_variables)
6355

6456
# even the file names from the LLC data differ from standard MITgcm output
6557
aliases = {'Eta': 'ETAN', 'PhiBot': 'PHIBOT', 'Salt': 'SALT',
@@ -72,54 +64,7 @@ def _get_var_metadata():
7264

7365
return var_metadata
7466

75-
_VAR_METADATA = _get_var_metadata()
76-
77-
def _is_vgrid(vname):
78-
# check for 1d, vertical grid variables
79-
dims = _VAR_METADATA[vname]['dims']
80-
return len(dims)==1 and dims[0][0]=='k'
81-
82-
def _get_variable_point(vname, mask_override):
83-
# fix for https://github.com/MITgcm/xmitgcm/issues/191
84-
if vname in mask_override:
85-
return mask_override[vname]
86-
dims = _VAR_METADATA[vname]['dims']
87-
if 'i' in dims and 'j' in dims:
88-
point = 'c'
89-
elif 'i_g' in dims and 'j' in dims:
90-
point = 'w'
91-
elif 'i' in dims and 'j_g' in dims:
92-
point = 's'
93-
elif 'i_g' in dims and 'j_g' in dims:
94-
raise ValueError("Don't have masks for corner points!")
95-
else:
96-
raise ValueError("Variable `%s` is not a horizontal variable." % vname)
97-
return point
98-
99-
def _get_scalars_and_vectors(varnames, type):
100-
101-
for vname in varnames:
102-
if vname not in _VAR_METADATA:
103-
raise ValueError("Varname `%s` not found in metadata." % vname)
104-
105-
if type != 'latlon':
106-
return varnames, []
10767

108-
scalars = []
109-
vector_pairs = []
110-
for vname in varnames:
111-
meta = _VAR_METADATA[vname]
112-
try:
113-
mate = meta['attrs']['mate']
114-
if mate not in varnames:
115-
raise ValueError("Vector pairs are required to create "
116-
"latlon type datasets. Varname `%s` is "
117-
"missing its vector mate `%s`"
118-
% vname, mate)
119-
vector_pairs.append((vname, mate))
120-
varnames.remove(mate)
121-
except KeyError:
122-
scalars.append(vname)
12368

12469
def _decompress(data, mask, dtype):
12570
data_blank = np.full_like(mask, np.nan, dtype=dtype)
@@ -604,6 +549,7 @@ class BaseLLCModel:
604549
varnames = []
605550
grid_varnames = []
606551
mask_override = {}
552+
var_metadata = None
607553
domain = 'global'
608554
pad_before = [0]*_nfacets
609555
pad_after = [0]*_nfacets
@@ -642,6 +588,53 @@ def _dtype(self,varname=None):
642588
elif isinstance(self.dtype,dict):
643589
return np.dtype(self.dtype[varname])
644590

591+
def _is_vgrid(self, vname):
592+
# check for 1d, vertical grid variables
593+
dims = self.var_metadata[vname]['dims']
594+
return len(dims)==1 and dims[0][0]=='k'
595+
596+
def _get_variable_point(self, vname, mask_override):
597+
# fix for https://github.com/MITgcm/xmitgcm/issues/191
598+
if vname in mask_override:
599+
return mask_override[vname]
600+
dims = self.var_metadata[vname]['dims']
601+
if 'i' in dims and 'j' in dims:
602+
point = 'c'
603+
elif 'i_g' in dims and 'j' in dims:
604+
point = 'w'
605+
elif 'i' in dims and 'j_g' in dims:
606+
point = 's'
607+
elif 'i_g' in dims and 'j_g' in dims:
608+
raise ValueError("Don't have masks for corner points!")
609+
else:
610+
raise ValueError("Variable `%s` is not a horizontal variable." % vname)
611+
return point
612+
613+
def _get_scalars_and_vectors(self, varnames, type):
614+
615+
for vname in varnames:
616+
if vname not in self.var_metadata:
617+
raise ValueError("Varname `%s` not found in metadata." % vname)
618+
619+
if type != 'latlon':
620+
return varnames, []
621+
622+
scalars = []
623+
vector_pairs = []
624+
for vname in varnames:
625+
meta = self.var_metadata[vname]
626+
try:
627+
mate = meta['attrs']['mate']
628+
if mate not in varnames:
629+
raise ValueError("Vector pairs are required to create "
630+
"latlon type datasets. Varname `%s` is "
631+
"missing its vector mate `%s`"
632+
% vname, mate)
633+
vector_pairs.append((vname, mate))
634+
varnames.remove(mate)
635+
except KeyError:
636+
scalars.append(vname)
637+
645638
def _get_kp1_levels(self,k_levels):
646639
# determine kp1 levels
647640
# get borders to all k (center) levels
@@ -739,7 +732,7 @@ def _dask_array_vgrid(self, varname, klevels, k_chunksize):
739732
name = '-'.join([varname, token])
740733
dtype = self._dtype(varname)
741734

742-
nz = self.nz if _VAR_METADATA[varname]['dims'] != ['k_p1'] else self.nz+1
735+
nz = self.nz if self.var_metadata[varname]['dims'] != ['k_p1'] else self.nz+1
743736
task = (_get_1d_chunk, self.store, varname,
744737
list(klevels), nz, dtype)
745738

@@ -750,12 +743,12 @@ def _dask_array_vgrid(self, varname, klevels, k_chunksize):
750743

751744
def _get_facet_data(self, varname, iters, klevels, k_chunksize):
752745
# needs facets to be outer index of nested lists
753-
dims = _VAR_METADATA[varname]['dims']
746+
dims = self.var_metadata[varname]['dims']
754747

755748
if len(dims)==2:
756749
klevels = [0,]
757750

758-
if _is_vgrid(varname):
751+
if self._is_vgrid(varname):
759752
data_facets = self._dask_array_vgrid(varname,klevels,k_chunksize)
760753
else:
761754
data_facets = [self._dask_array(nfacet, varname, iters, klevels, k_chunksize)
@@ -797,7 +790,8 @@ def _check_iters(self, iters):
797790

798791
def get_dataset(self, varnames=None, iter_start=None, iter_stop=None,
799792
iter_step=None, iters=None, k_levels=None, k_chunksize=1,
800-
type='faces', read_grid=True, grid_vars_to_coords=True):
793+
type='faces', read_grid=True, grid_vars_to_coords=True,
794+
extra_variables=None):
801795
"""
802796
Create an xarray Dataset object for this model.
803797
@@ -827,6 +821,22 @@ def get_dataset(self, varnames=None, iter_start=None, iter_stop=None,
827821
Whether to read the grid info
828822
grid_vars_to_coords : bool, optional
829823
Whether to promote grid variables to coordinate status
824+
extra_variables : dict, optional
825+
Allow to pass variables not listed in the variables.py
826+
or in available_diagnostics.log.
827+
extra_variables must be a dict containing the variable names as keys with
828+
the corresponging values being a dict with the keys being dims and attrs.
829+
830+
Syntax:
831+
extra_variables = dict(varname = dict(dims=list_of_dims, attrs=dict(optional_attrs)))
832+
where optional_attrs can contain standard_name, long_name, units as keys
833+
834+
Example:
835+
extra_variables = dict(
836+
ADJtheta = dict(dims=['k','j','i'], attrs=dict(
837+
standard_name='Sensitivity_to_theta',
838+
long_name='Sensitivity of cost function to theta', units='[J]/degC'))
839+
)
830840
831841
Returns
832842
-------
@@ -839,6 +849,7 @@ def _if_not_none(a, b):
839849
else:
840850
return a
841851

852+
self.var_metadata = _get_var_metadata(extra_variables=extra_variables)
842853
user_iter_params = [iter_start, iter_stop, iter_step]
843854
attribute_iter_params = [self.iter_start, self.iter_stop, self.iter_step]
844855

@@ -916,30 +927,30 @@ def _if_not_none(a, b):
916927
# do separately for vertical coords on kp1_levels
917928
grid_facets = {}
918929
for vname in grid_varnames:
919-
my_k_levels = k_levels if _VAR_METADATA[vname]['dims'] !=['k_p1'] else kp1_levels
930+
my_k_levels = k_levels if self.var_metadata[vname]['dims'] !=['k_p1'] else kp1_levels
920931
grid_facets[vname] = self._get_facet_data(vname, None, my_k_levels, k_chunksize)
921932

922933
# transform it into faces or latlon
923934
data_transformers = {'faces': _all_facets_to_faces,
924935
'latlon': _all_facets_to_latlon}
925936

926937
transformer = data_transformers[type]
927-
data = transformer(data_facets, _VAR_METADATA, self.nface)
938+
data = transformer(data_facets, self.var_metadata, self.nface)
928939

929940
# separate horizontal and vertical grid variables
930941
hgrid_facets = {key: grid_facets[key]
931-
for key in grid_varnames if not _is_vgrid(key)}
942+
for key in grid_varnames if not self._is_vgrid(key)}
932943
vgrid_facets = {key: grid_facets[key]
933-
for key in grid_varnames if _is_vgrid(key)}
944+
for key in grid_varnames if self._is_vgrid(key)}
934945

935946
# do not transform vertical grid variables
936-
data.update(transformer(hgrid_facets, _VAR_METADATA, self.nface))
947+
data.update(transformer(hgrid_facets, self.var_metadata, self.nface))
937948
data.update(vgrid_facets)
938949

939950
variables = {}
940951
gridlist = ['Zl','Zu'] if read_grid else []
941952
for vname in varnames+grid_varnames:
942-
meta = _VAR_METADATA[vname]
953+
meta = self.var_metadata[vname]
943954
dims = meta['dims']
944955
if type=='faces':
945956
dims = _add_face_to_dims(dims)
@@ -958,9 +969,9 @@ def _if_not_none(a, b):
958969
if read_grid and 'RF' in grid_varnames:
959970
ki = np.array([list(kp1_levels).index(x) for x in k_levels])
960971
for zv,sl in zip(['Zl','Zu'],[ki,ki+1]):
961-
variables[zv] = xr.Variable(_VAR_METADATA[zv]['dims'],
972+
variables[zv] = xr.Variable(self.var_metadata[zv]['dims'],
962973
data['RF'][sl],
963-
_VAR_METADATA[zv]['attrs'])
974+
self.var_metadata[zv]['attrs'])
964975

965976
ds = ds.update(variables)
966977

xmitgcm/mds_store.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def open_mdsdataset(data_dir, grid_dir=None,
6060
nx=None, ny=None, nz=None,
6161
llc_method="smallchunks", extra_metadata=None,
6262
extra_variables=None,
63-
custom_grid_variables={}):
63+
custom_grid_variables=None):
6464
"""Open MITgcm-style mds (.data / .meta) file output as xarray datset.
6565
6666
Parameters
@@ -382,7 +382,7 @@ def __init__(self, data_dir, grid_dir=None,
382382
nx=None, ny=None, nz=None, llc_method="smallchunks",
383383
levels=None, extra_metadata=None,
384384
extra_variables=None,
385-
custom_grid_variables={}):
385+
custom_grid_variables=None):
386386
"""
387387
This is not a user-facing class. See open_mdsdataset for argument
388388
documentation. The only ones which are distinct are.
@@ -869,17 +869,18 @@ def _get_extra_grid_variables(grid_dir, custom_grid_variables):
869869
Then return the variable information for each of these"""
870870
extra_grid = {}
871871

872-
all_extras = {**extra_grid_variables, **custom_grid_variables}
873-
fnames = dict([[val['filename'],key] for key,val in all_extras.items() if 'filename' in val])
872+
if custom_grid_variables is not None:
873+
extra_grid_variables = extra_grid_variables.update(custom_grid_variables)
874+
fnames = dict([[val['filename'],key] for key,val in extra_grid_variables.items() if 'filename' in val])
874875

875876
all_datafiles = listdir_endswith(grid_dir, '.data')
876877
for f in all_datafiles:
877878
prefix = os.path.split(f[:-5])[-1]
878879
# Only consider what we find that matches extra/custom_grid_vars
879-
if prefix in all_extras:
880-
extra_grid[prefix] = all_extras[prefix]
880+
if prefix in extra_grid_variables:
881+
extra_grid[prefix] = extra_grid_variables[prefix]
881882
elif prefix in fnames:
882-
extra_grid[fnames[prefix]] = all_extras[fnames[prefix]]
883+
extra_grid[fnames[prefix]] = extra_grid_variables[fnames[prefix]]
883884

884885
return extra_grid
885886

0 commit comments

Comments
 (0)