
Commit 4d2cc3b

Fix bug in _map_dims_to_ugrid, use Polars to improve SCRIP reader performance (#1109)
* Remove the check for n_edge, which was constructing connectivity
* Use Polars for the unique calls in the SCRIP reader
1 parent: c30f0b0

File tree: 6 files changed, +38 additions, -25 deletions


ci/asv.yml

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ dependencies:
   - pandas
   - pathlib
   - pre_commit
+  - polars
   - pyarrow
   - pytest
   - pytest-cov

ci/docs.yml

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ dependencies:
   - pandas
   - geocat-datafiles
   - spatialpandas
+  - polars
   - geopandas
   - pip:
     - antimeridian

ci/environment.yml

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ dependencies:
   - pandas
   - pathlib
   - pre_commit
+  - polars
   - pyarrow
   - pytest
   - pip

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ dependencies = [
   "geopandas",
   "xarray",
   "hvplot",
+  "polars",
 ]
 # minimal dependencies end


uxarray/core/utils.py

Lines changed: 9 additions & 4 deletions
@@ -39,16 +39,21 @@ def _map_dims_to_ugrid(
             # drop dimensions not present in the original dataset
             _source_dims_dict.pop(key)

+    # only check edge dimension if it is present (to avoid overhead of computing connectivity)
+    if "n_edge" in grid._ds.dims:
+        n_edge = grid._ds.sizes["n_edge"]
+    else:
+        n_edge = None
+
     for dim in set(ds.dims) ^ _source_dims_dict.keys():
         # obtain dimensions that were not parsed source_dims_dict and attempt to match to a grid element
         if ds.sizes[dim] == grid.n_face:
             _source_dims_dict[dim] = "n_face"
         elif ds.sizes[dim] == grid.n_node:
             _source_dims_dict[dim] = "n_node"
-        elif ds.sizes[dim] == grid.n_edge:
-            _source_dims_dict[dim] = "n_edge"
-
-        # Possible Issue: https://github.com/UXARRAY/uxarray/issues/610
+        elif n_edge is not None:
+            if ds.sizes[dim] == n_edge:
+                _source_dims_dict[dim] = "n_edge"

     # rename dimensions to follow the UGRID conventions
     ds = ds.swap_dims(_source_dims_dict)
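
The fix above hinges on how grid.n_edge behaves: per the commit message, evaluating it on a grid without edge information forced edge connectivity to be constructed just to answer a size comparison, whereas "n_edge" in grid._ds.dims is a plain dictionary lookup. A minimal sketch of that pattern, using a hypothetical Grid class rather than uxarray's actual implementation:

import numpy as np
import xarray as xr


class Grid:
    """Hypothetical stand-in for a grid whose edge dimension is built lazily."""

    def __init__(self, ds: xr.Dataset):
        self._ds = ds

    @property
    def n_edge(self) -> int:
        # Accessing this property triggers an expensive connectivity build
        # whenever "n_edge" is not already present in the underlying dataset.
        if "n_edge" not in self._ds.dims:
            self._build_edge_connectivity()
        return self._ds.sizes["n_edge"]

    def _build_edge_connectivity(self):
        # Placeholder for the costly step the old code paid on every call.
        edges = np.zeros((12, 2), dtype=np.int64)
        self._ds["edge_node_connectivity"] = (("n_edge", "two"), edges)


def edge_size_or_none(grid: Grid):
    # New-style check: a membership test on the underlying dims instead of
    # the property, so grids without edges never pay for the build.
    return grid._ds.sizes["n_edge"] if "n_edge" in grid._ds.dims else None

With that in place, the dimension-matching loop only compares against the edge size when it already exists, which is what the new "elif n_edge is not None" branch does.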

uxarray/io/_scrip.py

Lines changed: 25 additions & 21 deletions
@@ -1,6 +1,8 @@
 import xarray as xr
 import numpy as np

+import polars as pl
+
 from uxarray.grid.connectivity import _replace_fill_values
 from uxarray.constants import INT_DTYPE, INT_FILL_VALUE

@@ -11,43 +13,45 @@ def _to_ugrid(in_ds, out_ds):
     """If input dataset (``in_ds``) file is an unstructured SCRIP file,
     function will reassign SCRIP variables to UGRID conventions in output file
     (``out_ds``).
-
-    Parameters
-    ----------
-    in_ds : xarray.Dataset
-        Original scrip dataset of interest being used
-
-    out_ds : xarray.Variable
-        file to be returned by ``_populate_scrip_data``, used as an empty placeholder file
-        to store reassigned SCRIP variables in UGRID conventions
     """

     source_dims_dict = {}

     if in_ds["grid_area"].all():
         # Create node_lon & node_lat variables from grid_corner_lat/lon
-        # Turn latitude scrip array into 1D instead of 2D
+        # Turn latitude and longitude scrip arrays into 1D
         corner_lat = in_ds["grid_corner_lat"].values.ravel()
-
-        # Repeat above steps with longitude data instead
         corner_lon = in_ds["grid_corner_lon"].values.ravel()

-        # Combine flat lat and lon arrays
-        corner_lon_lat = np.vstack((corner_lon, corner_lat)).T
+        # Use Polars to find unique coordinate pairs
+        df = pl.DataFrame({"lon": corner_lon, "lat": corner_lat}).with_row_count(
+            "original_index"
+        )
+
+        # Get unique rows (first occurrence). This preserves the order in which they appear.
+        unique_df = df.unique(subset=["lon", "lat"], keep="first")
+
+        # unq_ind: The indices of the unique rows in the original array
+        unq_ind = unique_df["original_index"].to_numpy().astype(INT_DTYPE)
+
+        # To get the inverse index (unq_inv): map each original row back to its unique row index.
+        # Add a unique_id to the unique_df which will serve as the "inverse" mapping.
+        unique_df = unique_df.with_row_count("unique_id")

-        # Run numpy unique to determine which rows/values are actually unique
-        _, unq_ind, unq_inv = np.unique(
-            corner_lon_lat, return_index=True, return_inverse=True, axis=0
+        # Join original df with unique_df to find out which unique_id corresponds to each row
+        df_joined = df.join(
+            unique_df.drop("original_index"), on=["lon", "lat"], how="left"
         )
+        unq_inv = df_joined["unique_id"].to_numpy().astype(INT_DTYPE)

-        # Now, calculate unique lon and lat values to account for 'node_lon' and 'node_lat'
-        unq_lon = corner_lon_lat[unq_ind, :][:, 0]
-        unq_lat = corner_lon_lat[unq_ind, :][:, 1]
+        # Extract unique lon and lat values using unq_ind
+        unq_lon = corner_lon[unq_ind]
+        unq_lat = corner_lat[unq_ind]

         # Reshape face nodes array into original shape for use in 'face_node_connectivity'
         unq_inv = np.reshape(unq_inv, (len(in_ds.grid_size), len(in_ds.grid_corners)))

-        # Create node_lon & node_lat from unsorted, unique grid_corner_lat/lon
+        # Create node_lon & node_lat
         out_ds[ugrid.NODE_COORDINATES[0]] = xr.DataArray(
             unq_lon, dims=[ugrid.NODE_DIM], attrs=ugrid.NODE_LON_ATTRS
         )
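
For readers unfamiliar with the pattern, the Polars code reproduces what np.unique(corner_lon_lat, return_index=True, return_inverse=True, axis=0) used to return: unq_ind gives the position of each unique (lon, lat) pair in the flattened corner arrays, and unq_inv maps every corner back to its node. The sketch below is not uxarray code; it reuses the same Polars calls shown above (with_row_count has since been deprecated in favour of with_row_index on newer Polars releases) and adds a sort on the join result so the round-trip check does not depend on join ordering:

import numpy as np
import polars as pl

rng = np.random.default_rng(0)
# Fake "grid corner" coordinates with many repeated (lon, lat) pairs.
lon = rng.integers(0, 5, size=1_000).astype(float)
lat = rng.integers(0, 5, size=1_000).astype(float)

# Tag each row with its original position, then deduplicate on (lon, lat).
df = pl.DataFrame({"lon": lon, "lat": lat}).with_row_count("original_index")
unique_df = df.unique(subset=["lon", "lat"], keep="first").with_row_count("unique_id")

# unq_ind: where each unique pair occurs in the flattened arrays.
unq_ind = unique_df["original_index"].to_numpy()

# unq_inv: join every original row back to its unique_id, restoring row order
# explicitly so the inverse index lines up with the input.
joined = df.join(unique_df.drop("original_index"), on=["lon", "lat"], how="left")
unq_inv = joined.sort("original_index")["unique_id"].to_numpy()

# Round-trip check: indexing the unique values by the inverse map must
# reproduce the original arrays, the same guarantee np.unique provides
# (although np.unique returns its unique rows in sorted order instead).
assert np.array_equal(lon[unq_ind][unq_inv], lon)
assert np.array_equal(lat[unq_ind][unq_inv], lat)

The likely source of the speedup is that Polars deduplicates rows by hashing rather than the lexicographic sort np.unique performs on 2-D input, at the cost of the resulting nodes no longer being returned in sorted order.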
