11import numpy as np
22import pytest
3+ import shapely
4+ import shapely .testing
35import xarray as xr
46
57from grid_indexing import grids
68
79
8- class TestInferGridType :
9- def test_rectilinear_1d (self ):
10- lat = xr .Variable ("lat" , np .linspace (- 10 , 10 , 3 ), {"standard_name" : "latitude" })
11- lon = xr .Variable ("lon" , np .linspace (- 5 , 5 , 4 ), {"standard_name" : "longitude" })
12- ds = xr .Dataset (coords = {"lat" : lat , "lon" : lon })
13-
14- actual = grids .infer_grid_type (ds )
15- assert actual == "1d-rectilinear"
16-
17- def test_rectilinear_2d (self ):
18- lat_ , lon_ = np .meshgrid (np .linspace (- 10 , 10 , 3 ), np .linspace (- 5 , 5 , 4 ))
19- lat = xr .Variable (["y" , "x" ], lat_ , {"standard_name" : "latitude" })
20- lon = xr .Variable (["y" , "x" ], lon_ , {"standard_name" : "longitude" })
21- ds = xr .Dataset (coords = {"lat" : lat , "lon" : lon })
22-
23- actual = grids .infer_grid_type (ds )
24- assert actual == "2d-rectilinear"
25-
26- def test_curvilinear_2d (self ):
27- lat_ = np .array ([[0 , 0 , 0 , 0 ], [1 , 1 , 1 , 1 ], [2 , 2 , 2 , 2 ]])
28- lon_ = np .array ([[0 , 1 , 2 , 3 ], [1 , 2 , 3 , 4 ], [2 , 3 , 4 , 5 ]])
29-
30- lat = xr .Variable (["y" , "x" ], lat_ , {"standard_name" : "latitude" })
31- lon = xr .Variable (["y" , "x" ], lon_ , {"standard_name" : "longitude" })
32- ds = xr .Dataset (coords = {"lat" : lat , "lon" : lon })
33-
34- actual = grids .infer_grid_type (ds )
35- assert actual == "2d-curvilinear"
36-
37- def test_unstructured_1d (self ):
38- lat = xr .Variable (
39- "cells" , np .linspace (- 10 , 10 , 12 ), {"standard_name" : "latitude" }
40- )
41- lon = xr .Variable (
42- "cells" , np .linspace (- 5 , 5 , 12 ), {"standard_name" : "longitude" }
43- )
44- ds = xr .Dataset (coords = {"lat" : lat , "lon" : lon })
45-
46- actual = grids .infer_grid_type (ds )
47-
48- assert actual == "1d-unstructured"
10+ def example_dataset (grid_type ):
11+ match grid_type :
12+ case "1d-rectilinear" :
13+ lat_ = np .array ([0 , 2 ])
14+ lon_ = np .array ([0 , 2 , 4 ])
15+ lat = xr .Variable ("lat" , lat_ , {"standard_name" : "latitude" })
16+ lon = xr .Variable ("lon" , lon_ , {"standard_name" : "longitude" })
17+ ds = xr .Dataset (coords = {"lat" : lat , "lon" : lon })
18+ case "2d-rectilinear" :
19+ lat_ , lon_ = np .meshgrid (np .array ([0 , 2 ]), np .array ([0 , 2 , 4 ]))
20+ lat = xr .Variable (["y" , "x" ], lat_ , {"standard_name" : "latitude" })
21+ lon = xr .Variable (["y" , "x" ], lon_ , {"standard_name" : "longitude" })
22+ ds = xr .Dataset (coords = {"lat" : lat , "lon" : lon })
23+ case "2d-curvilinear" :
24+ lat_ = np .array ([[0 , 0 , 0 ], [2 , 2 , 2 ]])
25+ lon_ = np .array ([[0 , 2 , 4 ], [2 , 4 , 6 ]])
26+
27+ lat = xr .Variable (["y" , "x" ], lat_ , {"standard_name" : "latitude" })
28+ lon = xr .Variable (["y" , "x" ], lon_ , {"standard_name" : "longitude" })
29+ ds = xr .Dataset (coords = {"lat" : lat , "lon" : lon })
30+ case "1d-unstructured" :
31+ lat_ = np .arange (12 )
32+ lon_ = np .arange (- 5 , 7 )
33+ lat = xr .Variable ("cells" , lat_ , {"standard_name" : "latitude" })
34+ lon = xr .Variable ("cells" , lon_ , {"standard_name" : "longitude" })
35+ ds = xr .Dataset (coords = {"lat" : lat , "lon" : lon })
36+ case "2d-crs" :
37+ data = np .linspace (- 10 , 10 , 12 ).reshape (3 , 4 )
38+ geo_transform = (
39+ "101985.0 300.0379266750948 0.0 2826915.0 0.0 -300.041782729805"
40+ )
41+
42+ attrs = {
43+ "grid_mapping_name" : "transverse_mercator" ,
44+ "GeoTransform" : geo_transform ,
45+ }
4946
50- def test_crs_2d (self ):
51- data = np .linspace (- 10 , 10 , 12 ).reshape (3 , 4 )
52- geo_transform = "101985.0 300.0379266750948 0.0 2826915.0 0.0 -300.041782729805"
47+ ds = xr .Dataset (
48+ {"band_data" : (["y" , "x" ], data )},
49+ coords = {"spatial_ref" : ((), np .array (0 ), attrs )},
50+ )
51+
52+ return ds
53+
54+
55+ def example_geometries (grid_type ):
56+ if grid_type == "2d-crs" :
57+ raise NotImplementedError
58+
59+ match grid_type :
60+ case "1d-rectilinear" :
61+ boundaries = np .array (
62+ [
63+ [
64+ [[- 1 , - 1 ], [- 1 , 1 ], [1 , 1 ], [1 , - 1 ]],
65+ [[- 1 , 1 ], [- 1 , 3 ], [1 , 3 ], [1 , 1 ]],
66+ ],
67+ [
68+ [[1 , - 1 ], [1 , 1 ], [3 , 1 ], [3 , - 1 ]],
69+ [[1 , 1 ], [1 , 3 ], [3 , 3 ], [3 , 1 ]],
70+ ],
71+ [
72+ [[3 , - 1 ], [3 , 1 ], [5 , 1 ], [5 , - 1 ]],
73+ [[3 , 1 ], [3 , 3 ], [5 , 3 ], [5 , 1 ]],
74+ ],
75+ ]
76+ )
77+ case "2d-rectilinear" :
78+ boundaries = np .array (
79+ [
80+ [
81+ [[- 1 , - 1 ], [- 1 , 1 ], [1 , 1 ], [1 , - 1 ]],
82+ [[- 1 , 1 ], [- 1 , 3 ], [1 , 3 ], [1 , 1 ]],
83+ ],
84+ [
85+ [[1 , - 1 ], [1 , 1 ], [3 , 1 ], [3 , - 1 ]],
86+ [[1 , 1 ], [1 , 3 ], [3 , 3 ], [3 , 1 ]],
87+ ],
88+ [
89+ [[3 , - 1 ], [3 , 1 ], [5 , 1 ], [5 , - 1 ]],
90+ [[3 , 1 ], [3 , 3 ], [5 , 3 ], [5 , 1 ]],
91+ ],
92+ ]
93+ )
94+ case "2d-curvilinear" :
95+ boundaries = np .array (
96+ [
97+ [
98+ [[- 2 , - 1 ], [0 , - 1 ], [2 , 1 ], [0 , 1 ]],
99+ [[0 , - 1 ], [2 , - 1 ], [4 , 1 ], [2 , 1 ]],
100+ [[2 , - 1 ], [4 , - 1 ], [6 , 1 ], [4 , 1 ]],
101+ ],
102+ [
103+ [[0 , 1 ], [2 , 1 ], [4 , 3 ], [2 , 3 ]],
104+ [[2 , 1 ], [4 , 1 ], [6 , 3 ], [4 , 3 ]],
105+ [[4 , 1 ], [6 , 1 ], [8 , 3 ], [6 , 3 ]],
106+ ],
107+ ]
108+ )
109+
110+ return shapely .polygons (boundaries )
53111
54- ds = xr .Dataset (
55- {"band_data" : (["y" , "x" ], data )},
56- coords = {
57- "spatial_ref" : (
58- (),
59- np .array (0 ),
60- {
61- "grid_mapping_name" : "transverse_mercator" ,
62- "GeoTransform" : geo_transform ,
63- },
64- )
65- },
66- )
67112
113+ class TestInferGridType :
114+ @pytest .mark .parametrize (
115+ "grid_type" ,
116+ [
117+ "1d-rectilinear" ,
118+ "2d-rectilinear" ,
119+ "2d-curvilinear" ,
120+ "1d-unstructured" ,
121+ "2d-crs" ,
122+ ],
123+ )
124+ def test_infer_grid_type (self , grid_type ):
125+ ds = example_dataset (grid_type )
68126 actual = grids .infer_grid_type (ds )
69- assert actual == "2d-crs"
127+ assert actual == grid_type
70128
71129 def test_missing_spatial_coordinates (self ):
72130 ds = xr .Dataset ()
@@ -86,3 +144,46 @@ def test_unknown_grid_type(self):
86144
87145 with pytest .raises (ValueError , match = "unable to infer the grid type" ):
88146 grids .infer_grid_type (ds )
147+
148+
149+ class TestInferCellGeometries :
150+ @pytest .mark .parametrize (
151+ ["grid_type" , "error" , "pattern" ],
152+ (
153+ pytest .param ("2d-crs" , NotImplementedError , "geotransform" , id = "2d-crs" ),
154+ pytest .param (
155+ "1d-unstructured" ,
156+ ValueError ,
157+ "unstructured grids" ,
158+ id = "1d-unstructured" ,
159+ ),
160+ ),
161+ )
162+ def test_not_supported (self , grid_type , error , pattern ):
163+ ds = example_dataset (grid_type )
164+ with pytest .raises (error , match = pattern ):
165+ grids .infer_cell_geometries (ds )
166+
167+ def test_infer_coords (self ):
168+ ds = xr .Dataset ()
169+ with pytest .raises (ValueError , match = "cannot infer geographic coordinates" ):
170+ grids .infer_cell_geometries (ds , grid_type = "2d-rectilinear" )
171+
172+ @pytest .mark .parametrize (
173+ "grid_type" ,
174+ [
175+ "1d-rectilinear" ,
176+ "2d-rectilinear" ,
177+ "2d-curvilinear" ,
178+ pytest .param (
179+ "2d-crs" , marks = pytest .mark .xfail (reason = "not yet implemented" )
180+ ),
181+ ],
182+ )
183+ def test_infer_geoms (self , grid_type ):
184+ ds = example_dataset (grid_type )
185+ expected = example_geometries (grid_type )
186+
187+ actual = grids .infer_cell_geometries (ds , grid_type = grid_type )
188+
189+ shapely .testing .assert_geometries_equal (actual , expected )
0 commit comments