Skip to content

Commit 1fd135a

Browse files
committed
lint
1 parent a2c14a1 commit 1fd135a

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

extract_model/extract_model.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -281,9 +281,9 @@ def select(
281281
* False: 2D array of points with 1 dimension the lons and the other dimension the lats.
282282
* True: lons/lats as unstructured coordinate pairs (in xESMF language, LocStream).
283283
locstreamT: boolean, optional
284-
If False, interpolate in time dimension independently of horizontal points. If True, use advanced indexing/interpolation in xarray to interpolate times to each horizontal locstream point. If this is True, locstream must be True.
284+
If False, interpolate in time dimension independently of horizontal points. If True, use advanced indexing/interpolation in xarray to interpolate times to each horizontal locstream point.
285285
locstreamZ: boolean, optional
286-
If False, interpolate in depth dimension independently of horizontal points. If True, use advanced indexing after depth interpolation select depths to match each horizontal locstream point. If this is True, locstream must be True and locstreamT must be True.
286+
If False, interpolate in depth dimension independently of horizontal points. If True, use advanced indexing after depth interpolation select depths to match each horizontal locstream point.
287287
new_dim : str
288288
This is the name of the new dimension created if we are interpolating to a new set of points that are not a grid.
289289
weights: xESMF netCDF file path, DataArray, optional
@@ -360,14 +360,15 @@ def select(
360360
"Use extrap=True to extrapolate."
361361
)
362362

363-
if locstreamT:
364-
if not locstream:
365-
raise ValueError("if `locstreamT` is True, `locstream` must also be True.")
366-
if locstreamZ:
367-
if not locstream or not locstreamT:
368-
raise ValueError(
369-
"if `locstreamZ` is True, `locstream` and `locstreamT` must also be True."
370-
)
363+
# these are only true if interpolating in those directions too — need to fix them
364+
# if locstreamT:
365+
# if not locstream:
366+
# raise ValueError("if `locstreamT` is True, `locstream` must also be True.")
367+
# if locstreamZ:
368+
# if not locstream or not locstreamT:
369+
# raise ValueError(
370+
# "if `locstreamZ` is True, `locstream` and `locstreamT` must also be True."
371+
# )
371372

372373
# Perform interpolation
373374
if horizontal_interp:
@@ -443,13 +444,12 @@ def select(
443444
xs, ys = proj(xs, ys)
444445
x, y = proj(longitude, latitude)
445446

446-
# import pdb; pdb.set_trace()
447447
# lam = calc_barycentric(x, y, xs.reshape((10,9,3)), ys.reshape((10,9,3)))
448448
lam = calc_barycentric(x.flatten(), y.flatten(), xs, ys)
449449
# lam = calc_barycentric(x, y, xs, ys)
450450
# interp_coords are the coords and indices that went into the interpolation
451451
da, interp_coords = interp_with_barycentric(da, ixs, iys, lam)
452-
# import pdb; pdb.set_trace()
452+
453453
# if not locstream:
454454
# FIGURE OUT HOW TO RECONSTITUTE INTO GRID HERE
455455
kwargs_out["interp_coords"] = interp_coords
@@ -665,6 +665,7 @@ def pt_in_itriangle_proj(ix, iy):
665665

666666
# advanced indexing to select all assuming coherent time series
667667
# make sure len of each dimension matches
668+
668669
if locstreamZ:
669670

670671
dims_to_index = [da.cf["T"].name]
@@ -809,7 +810,6 @@ def sel2d(
809810
mask = mask.load()
810811

811812
# Assume mask is 2D — but not true for wetting/drying
812-
# import pdb; pdb.set_trace()
813813
# find indices representing mask
814814
eta, xi = np.where(mask.values)
815815

@@ -898,6 +898,10 @@ def sel2d(
898898

899899
else:
900900

901+
# make sure the mask matches
902+
msg = f"Mask {mask.name} dimensions do not match horizontal var {var.name} dimensions. mask dims: {mask.dims}, var dims: {var.dims}"
903+
assert len(set(mask.dims) - set(var.dims)) == 0, msg
904+
901905
# currently lons, lats 1D only
902906

903907
# if no mask, assume user just wants 1 nearest point to each input lons/lats pair
@@ -907,7 +911,7 @@ def sel2d(
907911
# if user inputs mask, use it to only return the nearest point that is active
908912
# so, find nearest 30 points to have options
909913
else:
910-
k = 30
914+
k = 50
911915

912916
distances, (iys, ixs) = tree_query(var[lonname], var[latname], lons, lats, k=k)
913917

@@ -916,7 +920,7 @@ def sel2d(
916920
raise ValueError("all found values are masked!")
917921

918922
if mask is not None:
919-
isorted_mask = np.argsort(-mask.values[iys, ixs], axis=-1)
923+
isorted_mask = np.argsort(-mask.values[iys, ixs], axis=-1, kind="mergesort")
920924
# sort the ixs and iys according to this sorting so that if there are unmasked indices,
921925
# they are leftmost also, and we will use the leftmost values.
922926
ixs_brought_along = np.take_along_axis(ixs, isorted_mask, axis=1)

0 commit comments

Comments
 (0)