11import numpy as np
22import xarray as xr
33import xcdat
4+ import xesmf
45
56
67def _calculate_2d_cell_bounds (
7- dimension : xr . DataArray ,
8+ points : np . ndarray ,
89 i : int ,
910 j : int ,
10- ) -> [ float , float , float , float ]:
11- cell_center = dimension [j , i ]. data
11+ ) -> list [ float ]:
12+ cell_center = points [j , i ]
1213 if i == 0 :
13- di = dimension [j , i + 1 ]. data - cell_center
14+ di = points [j , i + 1 ] - cell_center
1415 else :
15- di = cell_center - dimension [j , i - 1 ]. data
16+ di = cell_center - points [j , i - 1 ]
1617 if j == 0 :
17- dj = dimension [j + 1 , i ]. data - cell_center
18+ dj = points [j + 1 , i ] - cell_center
1819 else :
19- dj = cell_center - dimension [j - 1 , i ]. data
20+ dj = cell_center - points [j - 1 , i ]
2021
2122 return np .asarray (
2223 [
@@ -43,22 +44,20 @@ def decimate_rectilinear(dataset: xr.Dataset) -> xr.Dataset:
4344 """
4445 # Decimate the dataset, but update the bounds
4546 # 10x10 degree grid
46- regridded_vars = []
47-
48- for data_var in dataset .data_vars :
49- # Some datasets don't correctly use data_vars
50- if "_bnds" in data_var :
51- continue
52- output_grid = xcdat .create_uniform_grid (- 90 , 90 , 10 , 0 , 359 , 10 )
53- regridded_vars .append (
54- dataset .regridder .horizontal (
55- data_var ,
56- output_grid = output_grid ,
57- tool = "xesmf" ,
58- method = "bilinear" ,
59- )
60- )
61- return xr .merge (regridded_vars )
47+ output_grid = xcdat .create_uniform_grid (- 90 , 90 , 10 , 0 , 359 , 10 )
48+ regrid = xesmf .Regridder (dataset , output_grid , "bilinear" , periodic = True )
49+ result = regrid (dataset .copy ())
50+ result = result .bounds .add_bounds ("Y" ).bounds .add_bounds ("X" )
51+ # Restore attributes and add dataarrays that have not been regridded.
52+ for k , v in dataset .data_vars .items ():
53+ if k in result :
54+ result [k ].attrs = v .attrs
55+ else :
56+ result [k ] = v
57+ for k , v in dataset .coords .items ():
58+ result [k ].attrs = v .attrs
59+ result .attrs = dataset .attrs
60+ return result
6261
6362
6463def decimate_curvilinear (dataset : xr .Dataset , factor : int = 10 ) -> xr .Dataset :
@@ -82,13 +81,15 @@ def decimate_curvilinear(dataset: xr.Dataset, factor: int = 10) -> xr.Dataset:
8281 """
8382 assert factor >= 1
8483 result = dataset .interp (i = dataset .i [::factor ]).interp (j = dataset .j [::factor ])
85- result .coords ["i" ].values [:] = range (len (result .i ))
86- result .coords ["j" ].values [:] = range (len (result .j ))
84+ result .coords ["i" ].values [:] = np . arange (len (result .i ))
85+ result .coords ["j" ].values [:] = np . arange (len (result .j ))
8786
8887 # Update the bounds of the cells
88+ latitude_points = result .latitude .values
89+ longitude_points = result .longitude .values
8990 for j in result .j :
9091 for i in result .i :
91- result .vertices_latitude [j , i ] = _calculate_2d_cell_bounds (result . latitude , i , j )
92- result .vertices_longitude [j , i ] = _calculate_2d_cell_bounds (result . longitude , i , j )
92+ result .vertices_latitude [j , i ] = _calculate_2d_cell_bounds (latitude_points , i , j )
93+ result .vertices_longitude [j , i ] = _calculate_2d_cell_bounds (longitude_points , i , j )
9394
9495 return result
0 commit comments