@@ -50,7 +50,7 @@ def project(f: h5py.File) -> Iterator[str]:
5050
5151from absl import logging
5252import numpy
53- from sklearn import neighbors
53+ from scipy import spatial
5454from sklearn .metrics import pairwise
5555
5656from stac import bboxes
@@ -75,94 +75,31 @@ class ProjectionError(Error):
7575 """Error raised when we fail to project a coordinate into lat-lon."""
7676
7777
@dataclasses.dataclass(frozen=True)
class S1Interval:
  """A closed interval of longitudes on the unit circle, in degrees.

  The empty interval is encoded by the sentinel pair (low=180, high=-180).
  Instances are immutable; `add_longitude` returns a new interval.

  Properties:
    low: Minimum point in degrees.
    high: Maximum point in degrees. If high < low, then the interval is
      inverted (it wraps through +/-180). This can be used to detect
      antimeridian crossings when the input points are longitudes.
  """

  low: float
  high: float

  @classmethod
  def empty(cls) -> Self:
    """Returns an empty S1Interval."""
    return cls(180, -180)

  def is_empty(self) -> bool:
    """Returns True if the S1Interval is empty."""
    return self.low == 180 and self.high == -180

  def contains(self, longitude: float) -> bool:
    """Returns True if the given longitude is contained by the interval."""
    # The empty sentinel (180, -180) looks like an inverted interval, so
    # without this guard it would claim to contain +/-180.
    if self.is_empty():
      return False
    if self.low > self.high:
      # Inverted interval: it covers [low, 180] and [-180, high].
      return longitude >= self.low or longitude <= self.high
    return self.low <= longitude <= self.high

  @classmethod
  def positive_distance(cls, degrees_a: float, degrees_b: float) -> float:
    """Returns the distance travelled going from a to b in the positive direction.

    Args:
      degrees_a: Start point in degrees.
      degrees_b: End point in degrees.

    Returns:
      `degrees_b - degrees_a` when that is non-negative, otherwise that
      difference plus a full turn (360 degrees).
    """
    diff = degrees_b - degrees_a
    if diff >= 0:
      return diff
    # Equivalent to diff + 360, but written so that the two operands are
    # shifted before subtracting (kept from the original formulation).
    return (degrees_b + 180) - (degrees_a - 180)

  @classmethod
  def check(
      cls, degrees_low: float, degrees_high: float
  ) -> tuple[float, float]:
    """Returns the endpoints with -180 normalized to 180.

    An endpoint equal to -180 is replaced by 180 unless the other endpoint is
    180, so the full-circle pair (-180, 180) and the empty sentinel
    (180, -180) are left unchanged.

    Args:
      degrees_low: Low endpoint in degrees.
      degrees_high: High endpoint in degrees.

    Returns:
      The normalized (low, high) pair.
    """
    # Both decisions are made from the *input* values. Deciding the second
    # one from an already-rewritten low would turn (-180, -180) into the
    # empty sentinel (180, -180).
    low, high = degrees_low, degrees_high
    if degrees_low == -180 and degrees_high != 180:
      low = 180
    if degrees_high == -180 and degrees_low != 180:
      high = 180
    return low, high

  def add_longitude(self, lon: float) -> Self:
    """Returns an S1Interval that includes `lon`.

    Args:
      lon: Longitude in degrees, in [-180, 180]. -180 is treated as 180.

    Returns:
      `self` if `lon` is already covered, otherwise a new interval extended
      on whichever side needs the smaller extension.

    Raises:
      InputError: If `lon` is outside [-180, 180].
    """
    if math.fabs(lon) > 180:
      raise InputError(f'Cannot add longitude {lon:f} to S1Interval')
    if lon == -180:
      lon = 180

    interval_type = type(self)
    if self.is_empty():
      return interval_type(lon, lon)
    if self.contains(lon):
      return self

    # Extend whichever endpoint is closer to the new point.
    dist_low = self.positive_distance(lon, self.low)
    dist_high = self.positive_distance(self.high, lon)
    if dist_low < dist_high:
      return interval_type(*self.check(lon, self.high))
    return interval_type(*self.check(self.low, lon))
def latlon_to_xyz(
    lat_deg: Union[float, numpy.ndarray], lon_deg: Union[float, numpy.ndarray]
) -> numpy.ndarray:
  """Converts lat/lon in degrees to Cartesian (x, y, z) on the unit sphere.

  Args:
    lat_deg: Latitude(s) in degrees, as a scalar or an array.
    lon_deg: Longitude(s) in degrees, as a scalar or an array. It must be
      broadcastable against `lat_deg`.

  Returns:
    An array whose shape is the broadcast shape of the inputs with a trailing
    axis of length 3 holding (x, y, z). Two scalars yield shape (3,).
  """
  lat = numpy.radians(lat_deg)
  lon = numpy.radians(lon_deg)
  cos_lat = numpy.cos(lat)
  x = cos_lat * numpy.cos(lon)
  y = cos_lat * numpy.sin(lon)
  # z depends on latitude only, so it has to be expanded to the broadcast
  # shape of x/y before stacking; otherwise mixed-shape inputs (e.g. a scalar
  # latitude with an array of longitudes) make numpy.stack fail.
  z = numpy.broadcast_to(numpy.sin(lat), numpy.shape(x))
  return numpy.stack((x, y, z), axis=-1)
14788
14889
14990@dataclasses .dataclass (frozen = True )
15091class CoordinateIndex :
15192 """An index of (lat, lon) coordinates to their array offsets in 2D rasters.
15293
15394 Properties:
154- points: A list of (latitude, longitude) pairs.
155- point_index: A map from (latitude, longitude) to the (i, j) position in the
156- original 2D rasters.
157- bbox_list: The bounding boxes that cover all of `points`. Generally there
158- will only be one, but there will be two if the region crosses the
159- antimeridian, one for each side. Note that if 'points' represent pixel
160- centers, then the bounding boxes will not cover the entire border.
95+ points: A numpy array of (lat, lon) points found in the source rasters.
96+ bbox_list: The bounding boxes that cover all relevant coordinate points.
97+ source_indices: A numpy array of (i, j) offsets into the original rasters.
16198 """
16299
163- points : list [tuple [float , float ]]
164- point_index : dict [tuple [float , float ], tuple [int , int ]]
100+ points : numpy .ndarray
165101 bbox_list : list [bboxes .BBox ]
102+ source_indices : numpy .ndarray
166103
167104 @classmethod
168105 def from_arrays (
@@ -211,83 +148,75 @@ def from_arrays(
211148 )
212149 logging .info ('Source rasters have shape %s' , lat .shape )
213150
214- # Build a bounding box in each hemisphere. This will matter if we cross the
215- # antimeridian; we don't want a single box that spans from (-180, 180) with
216- # a lot of empty pixels in the middle.
217- west_bbox = None
218- east_bbox = None
219- s1_interval = S1Interval .empty ()
220- if mask is None :
221- mask = numpy .full ((lat .shape [0 ], lat .shape [1 ]), 1 , dtype = numpy .uint8 )
222-
223- points = []
224- point_index = {}
225- for i , (lat_col , lon_col , mask_col ) in enumerate (zip (lat , lon , mask )):
226- for j , (lat_ij , lon_ij , mask_ij ) in enumerate (
227- zip (lat_col , lon_col , mask_col )
228- ):
229- if (
230- lat_ij == lat_fill_value
231- or lon_ij == lon_fill_value
232- # == doesn't work for numpy.nan
233- or (numpy .isnan (lat_ij ) and numpy .isnan (lat_fill_value ))
234- or (numpy .isnan (lon_ij ) and numpy .isnan (lon_fill_value ))
235- ):
236- continue
237- lat_ij = lat_ij .item ()
238- lon_ij = lon_ij .item ()
239- if mask_ij != 0 :
240- points .append ((lat_ij , lon_ij ))
241- point_index [(lat_ij , lon_ij )] = (i , j )
242- s1_interval = s1_interval .add_longitude (lon_ij )
243-
244- # Update the appropriate bounding box depending on the hemisphere of
245- # this point.
246- if lon_ij <= 0 :
247- if west_bbox is None :
248- west_bbox = bboxes .BBox (lon_ij , lat_ij , lon_ij , lat_ij )
249- this_bbox = west_bbox
250- else :
251- if east_bbox is None :
252- east_bbox = bboxes .BBox (lon_ij , lat_ij , lon_ij , lat_ij )
253- this_bbox = east_bbox
254- if lat_ij < this_bbox .south :
255- this_bbox .south = lat_ij
256- if lat_ij > this_bbox .north :
257- this_bbox .north = lat_ij
258- if lon_ij < this_bbox .west :
259- this_bbox .west = lon_ij
260- if lon_ij > this_bbox .east :
261- this_bbox .east = lon_ij
151+ # Coords that are not fill values are used for BBox calculation.
152+ coords_valid = (lat != lat_fill_value ) & (lon != lon_fill_value )
153+ if numpy .isnan (lat_fill_value ):
154+ coords_valid &= ~ numpy .isnan (lat )
155+ if numpy .isnan (lon_fill_value ):
156+ coords_valid &= ~ numpy .isnan (lon )
262157
263- bbox_list = []
264- if s1_interval .is_empty ():
158+ if not numpy .any (coords_valid ):
265159 raise EmptyInputError ('The input grid was empty' )
266- elif s1_interval .low > s1_interval .high :
267- # If the S1Interval is inverted, then we crossed the antimeridian.
268- if east_bbox is None or west_bbox is None :
269- raise ProjectionError (
270- 'The antimeridian was crossed but one hemisphere has a null bbox'
271- )
272- bbox_list .extend ((east_bbox , west_bbox ))
160+
161+ valid_lat = lat [coords_valid ]
162+ valid_lon = lon [coords_valid ]
163+
164+ # Minimal enclosing interval on S1 to detect antimeridian crossing.
165+ lons_to_check = valid_lon .copy ()
166+ lons_to_check [lons_to_check == - 180 ] = 180
167+ lons_unique = numpy .unique (lons_to_check )
168+ if lons_unique .size == 1 :
169+ s1_low , s1_high = float (lons_unique [0 ]), float (lons_unique [0 ])
273170 else :
274- if east_bbox is not None :
275- if west_bbox is not None :
276- # We crossed the prime meridian, which is fine.
277- bbox_list .append (east_bbox .union (west_bbox ))
278- else :
279- bbox_list .append (east_bbox )
280- elif west_bbox is not None :
281- bbox_list .append (west_bbox )
171+ gaps = numpy .diff (lons_unique )
172+ wrap_gap = 360.0 - (lons_unique [- 1 ] - lons_unique [0 ])
173+ max_gap_idx = numpy .argmax (gaps )
174+ if wrap_gap >= gaps [max_gap_idx ]:
175+ s1_low , s1_high = float (lons_unique [0 ]), float (lons_unique [- 1 ])
282176 else :
283- raise EmptyInputError ('The input grid was empty' )
177+ s1_low , s1_high = (
178+ float (lons_unique [max_gap_idx + 1 ]),
179+ float (lons_unique [max_gap_idx ]),
180+ )
181+
182+ # Points for GLT are also filtered by `mask`.
183+ glt_valid = coords_valid .copy ()
184+ if mask is not None :
185+ glt_valid = glt_valid & (mask != 0 )
186+
187+ if not numpy .any (glt_valid ):
188+ raise EmptyInputError ('The input grid had no unmasked points' )
189+
190+ idx_i , idx_j = numpy .where (glt_valid )
191+ points = numpy .stack ([lat [glt_valid ], lon [glt_valid ]], axis = - 1 )
192+ source_indices = numpy .stack ([idx_i , idx_j ], axis = - 1 )
193+
194+ # Build a bounding box in each hemisphere.
195+ is_west = valid_lon <= 0
196+ is_east = ~ is_west
197+
198+ bbox_list = []
199+ if numpy .any (is_west ):
200+ w_lats = valid_lat [is_west ]
201+ w_lons = valid_lon [is_west ]
202+ bbox_list .append (
203+ bboxes .BBox (w_lons .min (), w_lats .min (), w_lons .max (), w_lats .max ())
204+ )
284205
285- if not points :
286- raise EmptyInputError (
287- 'No points mapped by CoordinateIndex with bounding boxes: %s' ,
288- bbox_list ,
206+ if numpy .any (is_east ):
207+ e_lats = valid_lat [is_east ]
208+ e_lons = valid_lon [is_east ]
209+ bbox_list .insert (
210+ 0 , bboxes .BBox (e_lons .min (), e_lats .min (), e_lons .max (), e_lats .max ())
289211 )
290- return cls (points , point_index , bbox_list )
212+
213+ if s1_low <= s1_high :
214+ # If we didn't cross the antimeridian, union the boxes. In this case, we
215+ # crossed the prime meridian, but that's fine.
216+ if len (bbox_list ) == 2 :
217+ bbox_list = [bbox_list [0 ].union (bbox_list [1 ])]
218+
219+ return cls (points , bbox_list , source_indices )
291220
292221
293222@dataclasses .dataclass (frozen = True )
@@ -357,17 +286,19 @@ def from_index(
357286
358287 # When filling in the corrected grid, we pick a pixel by finding the nearest
359288 # original point to the given position. We fill any gaps by choosing the
360- # nearest neighbor as long as it is not too many pixels away.
361- tree = neighbors .BallTree (
362- [(math .radians (x ), math .radians (y )) for x , y in index .points ],
363- metric = 'haversine' ,
364- leaf_size = 10 ,
365- )
366- max_distance = max_nn_distance * max (
289+ # nearest neighbor as long as it is not too many pixels away. We use the
290+ # Euclidean distance because it's faster than Haversine and roughly
291+ # equivalent.
292+ source_xyz = latlon_to_xyz (index .points [:, 0 ], index .points [:, 1 ])
293+ tree = spatial .cKDTree (source_xyz , leafsize = 10 )
294+
295+ # Convert the angular max distance to chord distance.
296+ max_theta = max_nn_distance * max (
367297 pairwise .haversine_distances (
368298 [[0 , 0 ], [math .radians (scale_lat ), math .radians (scale_lon )]]
369299 )[0 ]
370300 )
301+ max_chord = 2 * math .sin (max_theta / 2 )
371302
372303 tables = []
373304 for bbox in index .bbox_list :
@@ -379,53 +310,66 @@ def from_index(
379310 north = max (- 90 , min (90 , bbox .north - (scale_lat / 2 ))),
380311 )
381312
382- # Preallocate the GLTs so we only have to do assignment below.
383- # This should speed things up when the size of the grid is large, which
384- # tends to happen the farther you get from the equator.
313+ # Preallocate the GLT as a single numpy array.
385314 num_cols = int (math .ceil ((bbox .south - bbox .north ) / scale_lat ))
386315 num_rows = int (math .ceil ((bbox .east - bbox .west ) / scale_lon ))
387- glt_i = [[GLT_FILL_VALUE ] * num_rows for _ in range (0 , num_cols )]
388- glt_j = [[GLT_FILL_VALUE ] * num_rows for _ in range (0 , num_cols )]
316+ glt_full = numpy .full (
317+ (num_cols , num_rows , 2 ), GLT_FILL_VALUE , dtype = numpy .int64
318+ )
389319 logging .info ('GLT will have shape (%d, %d)' , num_cols , num_rows )
390320
391- # Further speed things up by working in parallel.
392- # The "cell-var-from-loop" warnings can be ignored because those vars
393- # are global from this method's point of view.
394- # pylint: disable=cell-var-from-loop
395- def _fill_glt_col (col_idx : int ) -> None :
396- """Populates `glt_i` and 'glt_j` for a single column."""
397- lat = bbox .north + (col_idx * scale_lat ) + (scale_lat / 2 )
398- lons = [
399- bbox .west + (row_idx * scale_lon ) + (scale_lon / 2 )
400- for row_idx in range (0 , num_rows )
401- ]
402-
403- dd , ii = tree .query (
404- [(math .radians (lat ), math .radians (lon )) for lon in lons ], k = 1
321+ # Generate all latitude centers and longitude centers.
322+ lats = bbox .north + numpy .arange (num_cols ) * scale_lat + (scale_lat / 2 )
323+ lons = bbox .west + numpy .arange (num_rows ) * scale_lon + (scale_lon / 2 )
324+
325+ # Define a block size for batch processing.
326+ block_size = 100
327+
def _fill_glt_block(col_start: int) -> None:
  """Fills `glt_full[col_start:col_start + block_size]` with source offsets.

  For every output cell in the block, the nearest source point is looked up
  in `tree`; cells with no source point within `max_chord` keep
  GLT_FILL_VALUE.

  Args:
    col_start: Index of the first entry of `lats` handled by this block.
  """
  col_stop = min(col_start + block_size, num_cols)

  # Evaluate the trig terms once per latitude and once per longitude, then
  # combine them by broadcasting into a (block_lats, lons) grid.
  lat_rad = numpy.radians(lats[col_start:col_stop])
  lon_rad = numpy.radians(lons)
  cos_lat_col = numpy.cos(lat_rad)[:, numpy.newaxis]
  sin_lat_col = numpy.sin(lat_rad)[:, numpy.newaxis]
  cos_lon_row = numpy.cos(lon_rad)[numpy.newaxis, :]
  sin_lon_row = numpy.sin(lon_rad)[numpy.newaxis, :]

  x = cos_lat_col * cos_lon_row
  y = cos_lat_col * sin_lon_row
  # z only varies with latitude; view it at the grid shape for stacking.
  z = numpy.broadcast_to(sin_lat_col, x.shape)
  query_xyz = numpy.stack((x, y, z), axis=-1).reshape(-1, 3)

  # The upper bound lets the tree prune cells that are far from any data.
  # Per the SciPy docs, cells without a neighbour in range come back with an
  # infinite distance, which the comparison below rejects.
  distances, nearest = tree.query(
      query_xyz, k=1, distance_upper_bound=max_chord
  )
  within_reach = distances <= max_chord

  # Start from fill values and overwrite only the cells that found a match.
  offsets = numpy.full((nearest.size, 2), GLT_FILL_VALUE, dtype=numpy.int64)
  offsets[within_reach] = index.source_indices[nearest[within_reach]]

  glt_full[col_start:col_stop, :, :] = offsets.reshape(
      (col_stop - col_start, num_rows, 2)
  )
406- for row_idx , (near_dist , near_idx ) in enumerate (zip (dd , ii )):
407- if near_dist [0 ] > max_distance :
408- continue
409- elif near_idx [0 ] >= len (index .points ):
410- raise ProjectionError ('Bad nearest index {}' .format (near_idx [0 ]))
411- else :
412- orig_point = index .points [near_idx [0 ]]
413- orig_ij = index .point_index .get (orig_point )
414- if orig_ij is None :
415- raise ProjectionError ('Bad nearest point {}' .format (orig_point ))
416- glt_i [col_idx ][row_idx ] = orig_ij [0 ]
417- glt_j [col_idx ][row_idx ] = orig_ij [1 ]
418-
419- # pylint: enable=cell-var-from-loop
420365
421366 with concurrent .futures .ThreadPoolExecutor (
422367 max_workers = num_threads
423368 ) as executor :
424- for col_idx in range (0 , num_cols ):
425- executor .submit (_fill_glt_col , col_idx )
369+ for col_start in range (0 , num_cols , block_size ):
370+ executor .submit (_fill_glt_block , col_start )
426371
427- glt = numpy .stack ((glt_i , glt_j ), axis = - 1 , dtype = numpy .int64 )
428- tables .append (cls (bbox , scale_lat , scale_lon , glt ))
372+ tables .append (cls (bbox , scale_lat , scale_lon , glt_full ))
429373
430374 return tables
431375
0 commit comments