Skip to content

Commit 4a6424f

Browse files
dpencoskcopybara-github
authored andcommitted
Improve geocorrect.py performance by vectorizing GLT construction.
This change replaces inefficient Pythonic loops to take advantage of the many optimizations in the numpy library for applying mathematical operations to potentially large multi-dimensional vectors. We do larger chunks of work in parallel and replace the data structure used to fetch the closed points in the projected grid with a more efficient one. Following the instructions in geocorrect_eval.py, we downloaded EMIT_L2A_RFL_001_20260224T170028_2605511_003.nc and evaluated GLT construction time and accuracy by measuring the Haversine distance between projected pixels and their true location taken from the original, unprojected raster. We observed no change in accuracy while achieving a more than 50x speedup: Previous implementation * GLT construction took 463.1493 seconds - Mean: 24.957609132296444 , Median: 25.994672579115598 - Percentile (50, 90, 95, 99): [25.99467258 36.31322475 38.89952704 42.40422599] * Reference GLT (contained in source file) - Mean difference: 33.09322282322614 , Median: 33.20680931214716 - Percentile (50, 90, 95, 99): [33.20680931 53.28072494 56.83186525 62.86063453] Vectorized implementation * GLT construction took 6.6420 seconds - Mean: 24.957609132296444 , Median: 25.994672579115598 - Percentile (50, 90, 95, 99): [25.99467258 36.31322475 38.89952704 42.40422599] * Reference GLT (contained in source file) - Mean: 33.09322282322614 , Median: 33.20680931214716 - Percentile (50, 90, 95, 99): [33.20680931 53.28072494 56.83186525 62.86063453] PiperOrigin-RevId: 876303127
1 parent 2e2979e commit 4a6424f

File tree

3 files changed

+198
-284
lines changed

3 files changed

+198
-284
lines changed

pipelines/geocorrect.py

Lines changed: 138 additions & 194 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def project(f: h5py.File) -> Iterator[str]:
5050

5151
from absl import logging
5252
import numpy
53-
from sklearn import neighbors
53+
from scipy import spatial
5454
from sklearn.metrics import pairwise
5555

5656
from stac import bboxes
@@ -75,94 +75,31 @@ class ProjectionError(Error):
7575
"""Error raised when we fail to project a coordinate into lat-lon."""
7676

7777

78-
@dataclasses.dataclass(frozen=True)
79-
class S1Interval:
80-
"""An S1Interval represents a closed interval on a unit circle.
81-
82-
Properties:
83-
low: Minimum point in degrees.
84-
high: Maximum point in degrees. If high < low, then the interval is
85-
inverted. This can be used to detect antimeridian crossings when the
86-
input points are longitudes.
87-
"""
88-
89-
low: float
90-
high: float
91-
92-
@classmethod
93-
def empty(cls) -> Self:
94-
"""Returns an empty S1Interval."""
95-
return cls(180, -180)
96-
97-
def is_empty(self) -> bool:
98-
"""Returns True if the S1Interval is empty."""
99-
return self.low == 180 and self.high == -180
100-
101-
def contains(self, longitude: float) -> bool:
102-
"""Returns True if the given longitude is contained by the interval."""
103-
if self.low > self.high:
104-
return longitude >= self.low or longitude <= self.high
105-
else:
106-
return longitude >= self.low and longitude <= self.high
107-
108-
@classmethod
109-
def positive_distance(cls, degrees_a: float, degrees_b: float) -> float:
110-
"""Returns the distance between two points in the range [0, 360)."""
111-
diff = degrees_b - degrees_a
112-
if diff >= 0:
113-
return diff
114-
else:
115-
# If b is 180 and a is -180 + epsilon, we'd prefer to return 360.
116-
return (degrees_b + 180) - (degrees_a - 180)
117-
118-
@classmethod
119-
def check(
120-
cls, degrees_low: float, degrees_high: float
121-
) -> tuple[float, float]:
122-
"""Returns points so that the low-high interval is on a unit circle."""
123-
if degrees_low == -180 and degrees_high != 180:
124-
degrees_low = math.pi
125-
if degrees_high == -180 and degrees_low != 180:
126-
degrees_high = math.pi
127-
return degrees_low, degrees_high
128-
129-
def add_longitude(self, lon: float) -> Self:
130-
"""Returns an S1Interval that includes `lon`."""
131-
if math.fabs(lon) > 180:
132-
raise InputError('Cannot add latitude %f to S1Interval'.format(lon))
133-
if lon == -180:
134-
lon = 180
135-
136-
if self.is_empty():
137-
return S1Interval(lon, lon)
138-
elif self.contains(lon):
139-
return self
140-
else:
141-
dist_low = self.positive_distance(lon, self.low)
142-
dist_high = self.positive_distance(self.high, lon)
143-
if dist_low < dist_high:
144-
return S1Interval(*self.check(lon, self.high))
145-
else:
146-
return S1Interval(*self.check(self.low, lon))
78+
def latlon_to_xyz(
79+
lat_deg: Union[float, numpy.ndarray], lon_deg: Union[float, numpy.ndarray]
80+
) -> numpy.ndarray:
81+
"""Converts lat/lon in degrees to Cartesian (x, y, z) on unit sphere."""
82+
lat = numpy.radians(lat_deg)
83+
lon = numpy.radians(lon_deg)
84+
x = numpy.cos(lat) * numpy.cos(lon)
85+
y = numpy.cos(lat) * numpy.sin(lon)
86+
z = numpy.sin(lat)
87+
return numpy.stack((x, y, z), axis=-1)
14788

14889

14990
@dataclasses.dataclass(frozen=True)
15091
class CoordinateIndex:
15192
"""An index of (lat, lon) coordinates to their array offsets in 2D rasters.
15293
15394
Properties:
154-
points: A list of (latitude, longitude) pairs.
155-
point_index: A map from (latitude, longitude) to the (i, j) position in the
156-
original 2D rasters.
157-
bbox_list: The bounding boxes that cover all of `points`. Generally there
158-
will only be one, but there will be two if the region crosses the
159-
antimeridian, one for each side. Note that if 'points' represent pixel
160-
centers, then the bounding boxes will not cover the entire border.
95+
points: A numpy array of (lat, lon) points found in the source rasters.
96+
bbox_list: The bounding boxes that cover all relevant coordinate points.
97+
source_indices: A numpy array of (i, j) offsets into the original rasters.
16198
"""
16299

163-
points: list[tuple[float, float]]
164-
point_index: dict[tuple[float, float], tuple[int, int]]
100+
points: numpy.ndarray
165101
bbox_list: list[bboxes.BBox]
102+
source_indices: numpy.ndarray
166103

167104
@classmethod
168105
def from_arrays(
@@ -211,83 +148,75 @@ def from_arrays(
211148
)
212149
logging.info('Source rasters have shape %s', lat.shape)
213150

214-
# Build a bounding box in each hemisphere. This will matter if we cross the
215-
# antimeridian; we don't want a single box that spans from (-180, 180) with
216-
# a lot of empty pixels in the middle.
217-
west_bbox = None
218-
east_bbox = None
219-
s1_interval = S1Interval.empty()
220-
if mask is None:
221-
mask = numpy.full((lat.shape[0], lat.shape[1]), 1, dtype=numpy.uint8)
222-
223-
points = []
224-
point_index = {}
225-
for i, (lat_col, lon_col, mask_col) in enumerate(zip(lat, lon, mask)):
226-
for j, (lat_ij, lon_ij, mask_ij) in enumerate(
227-
zip(lat_col, lon_col, mask_col)
228-
):
229-
if (
230-
lat_ij == lat_fill_value
231-
or lon_ij == lon_fill_value
232-
# == doesn't work for numpy.nan
233-
or (numpy.isnan(lat_ij) and numpy.isnan(lat_fill_value))
234-
or (numpy.isnan(lon_ij) and numpy.isnan(lon_fill_value))
235-
):
236-
continue
237-
lat_ij = lat_ij.item()
238-
lon_ij = lon_ij.item()
239-
if mask_ij != 0:
240-
points.append((lat_ij, lon_ij))
241-
point_index[(lat_ij, lon_ij)] = (i, j)
242-
s1_interval = s1_interval.add_longitude(lon_ij)
243-
244-
# Update the appropriate bounding box depending on the hemisphere of
245-
# this point.
246-
if lon_ij <= 0:
247-
if west_bbox is None:
248-
west_bbox = bboxes.BBox(lon_ij, lat_ij, lon_ij, lat_ij)
249-
this_bbox = west_bbox
250-
else:
251-
if east_bbox is None:
252-
east_bbox = bboxes.BBox(lon_ij, lat_ij, lon_ij, lat_ij)
253-
this_bbox = east_bbox
254-
if lat_ij < this_bbox.south:
255-
this_bbox.south = lat_ij
256-
if lat_ij > this_bbox.north:
257-
this_bbox.north = lat_ij
258-
if lon_ij < this_bbox.west:
259-
this_bbox.west = lon_ij
260-
if lon_ij > this_bbox.east:
261-
this_bbox.east = lon_ij
151+
# Coords that are not fill values are used for BBox calculation.
152+
coords_valid = (lat != lat_fill_value) & (lon != lon_fill_value)
153+
if numpy.isnan(lat_fill_value):
154+
coords_valid &= ~numpy.isnan(lat)
155+
if numpy.isnan(lon_fill_value):
156+
coords_valid &= ~numpy.isnan(lon)
262157

263-
bbox_list = []
264-
if s1_interval.is_empty():
158+
if not numpy.any(coords_valid):
265159
raise EmptyInputError('The input grid was empty')
266-
elif s1_interval.low > s1_interval.high:
267-
# If the S1Interval is inverted, then we crossed the antimeridian.
268-
if east_bbox is None or west_bbox is None:
269-
raise ProjectionError(
270-
'The antimeridian was crossed but one hemisphere has a null bbox'
271-
)
272-
bbox_list.extend((east_bbox, west_bbox))
160+
161+
valid_lat = lat[coords_valid]
162+
valid_lon = lon[coords_valid]
163+
164+
# Minimal enclosing interval on S1 to detect antimeridian crossing.
165+
lons_to_check = valid_lon.copy()
166+
lons_to_check[lons_to_check == -180] = 180
167+
lons_unique = numpy.unique(lons_to_check)
168+
if lons_unique.size == 1:
169+
s1_low, s1_high = float(lons_unique[0]), float(lons_unique[0])
273170
else:
274-
if east_bbox is not None:
275-
if west_bbox is not None:
276-
# We crossed the prime meridian, which is fine.
277-
bbox_list.append(east_bbox.union(west_bbox))
278-
else:
279-
bbox_list.append(east_bbox)
280-
elif west_bbox is not None:
281-
bbox_list.append(west_bbox)
171+
gaps = numpy.diff(lons_unique)
172+
wrap_gap = 360.0 - (lons_unique[-1] - lons_unique[0])
173+
max_gap_idx = numpy.argmax(gaps)
174+
if wrap_gap >= gaps[max_gap_idx]:
175+
s1_low, s1_high = float(lons_unique[0]), float(lons_unique[-1])
282176
else:
283-
raise EmptyInputError('The input grid was empty')
177+
s1_low, s1_high = (
178+
float(lons_unique[max_gap_idx + 1]),
179+
float(lons_unique[max_gap_idx]),
180+
)
181+
182+
# Points for GLT are also filtered by `mask`.
183+
glt_valid = coords_valid.copy()
184+
if mask is not None:
185+
glt_valid = glt_valid & (mask != 0)
186+
187+
if not numpy.any(glt_valid):
188+
raise EmptyInputError('The input grid had no unmasked points')
189+
190+
idx_i, idx_j = numpy.where(glt_valid)
191+
points = numpy.stack([lat[glt_valid], lon[glt_valid]], axis=-1)
192+
source_indices = numpy.stack([idx_i, idx_j], axis=-1)
193+
194+
# Build a bounding box in each hemisphere.
195+
is_west = valid_lon <= 0
196+
is_east = ~is_west
197+
198+
bbox_list = []
199+
if numpy.any(is_west):
200+
w_lats = valid_lat[is_west]
201+
w_lons = valid_lon[is_west]
202+
bbox_list.append(
203+
bboxes.BBox(w_lons.min(), w_lats.min(), w_lons.max(), w_lats.max())
204+
)
284205

285-
if not points:
286-
raise EmptyInputError(
287-
'No points mapped by CoordinateIndex with bounding boxes: %s',
288-
bbox_list,
206+
if numpy.any(is_east):
207+
e_lats = valid_lat[is_east]
208+
e_lons = valid_lon[is_east]
209+
bbox_list.insert(
210+
0, bboxes.BBox(e_lons.min(), e_lats.min(), e_lons.max(), e_lats.max())
289211
)
290-
return cls(points, point_index, bbox_list)
212+
213+
if s1_low <= s1_high:
214+
# If we didn't cross the antimeridian, union the boxes. In this case, we
215+
# crossed the prime meridian, but that's fine.
216+
if len(bbox_list) == 2:
217+
bbox_list = [bbox_list[0].union(bbox_list[1])]
218+
219+
return cls(points, bbox_list, source_indices)
291220

292221

293222
@dataclasses.dataclass(frozen=True)
@@ -357,17 +286,19 @@ def from_index(
357286

358287
# When filling in the corrected grid, we pick a pixel by finding the nearest
359288
# original point to the given position. We fill any gaps by choosing the
360-
# nearest neighbor as long as it is not too many pixels away.
361-
tree = neighbors.BallTree(
362-
[(math.radians(x), math.radians(y)) for x, y in index.points],
363-
metric='haversine',
364-
leaf_size=10,
365-
)
366-
max_distance = max_nn_distance * max(
289+
# nearest neighbor as long as it is not too many pixels away. We use the
290+
# Euclidean distance because it's faster than Haversine and roughly
291+
# equivalent.
292+
source_xyz = latlon_to_xyz(index.points[:, 0], index.points[:, 1])
293+
tree = spatial.cKDTree(source_xyz, leafsize=10)
294+
295+
# Convert the angular max distance to chord distance.
296+
max_theta = max_nn_distance * max(
367297
pairwise.haversine_distances(
368298
[[0, 0], [math.radians(scale_lat), math.radians(scale_lon)]]
369299
)[0]
370300
)
301+
max_chord = 2 * math.sin(max_theta / 2)
371302

372303
tables = []
373304
for bbox in index.bbox_list:
@@ -379,53 +310,66 @@ def from_index(
379310
north=max(-90, min(90, bbox.north - (scale_lat / 2))),
380311
)
381312

382-
# Preallocate the GLTs so we only have to do assignment below.
383-
# This should speed things up when the size of the grid is large, which
384-
# tends to happen the farther you get from the equator.
313+
# Preallocate the GLT as a single numpy array.
385314
num_cols = int(math.ceil((bbox.south - bbox.north) / scale_lat))
386315
num_rows = int(math.ceil((bbox.east - bbox.west) / scale_lon))
387-
glt_i = [[GLT_FILL_VALUE] * num_rows for _ in range(0, num_cols)]
388-
glt_j = [[GLT_FILL_VALUE] * num_rows for _ in range(0, num_cols)]
316+
glt_full = numpy.full(
317+
(num_cols, num_rows, 2), GLT_FILL_VALUE, dtype=numpy.int64
318+
)
389319
logging.info('GLT will have shape (%d, %d)', num_cols, num_rows)
390320

391-
# Further speed things up by working in parallel.
392-
# The "cell-var-from-loop" warnings can be ignored because those vars
393-
# are global from this method's point of view.
394-
# pylint: disable=cell-var-from-loop
395-
def _fill_glt_col(col_idx: int) -> None:
396-
"""Populates `glt_i` and 'glt_j` for a single column."""
397-
lat = bbox.north + (col_idx * scale_lat) + (scale_lat / 2)
398-
lons = [
399-
bbox.west + (row_idx * scale_lon) + (scale_lon / 2)
400-
for row_idx in range(0, num_rows)
401-
]
402-
403-
dd, ii = tree.query(
404-
[(math.radians(lat), math.radians(lon)) for lon in lons], k=1
321+
# Generate all latitude centers and longitude centers.
322+
lats = bbox.north + numpy.arange(num_cols) * scale_lat + (scale_lat / 2)
323+
lons = bbox.west + numpy.arange(num_rows) * scale_lon + (scale_lon / 2)
324+
325+
# Define a block size for batch processing.
326+
block_size = 100
327+
328+
def _fill_glt_block(col_start: int) -> None:
329+
"""Populates a block of columns in `glt_full`."""
330+
col_end = min(col_start + block_size, num_cols)
331+
block_lats = lats[col_start:col_end]
332+
333+
# Optimize XYZ coordinate generation using broadcasting.
334+
# This avoids redundant meshgrid and cos/sin operations on large grids.
335+
rad_lat = numpy.radians(block_lats)
336+
rad_lon = numpy.radians(lons)
337+
cos_lat = numpy.cos(rad_lat)
338+
sin_lat = numpy.sin(rad_lat)
339+
cos_lon = numpy.cos(rad_lon)
340+
sin_lon = numpy.sin(rad_lon)
341+
342+
# Broadcasting to create (N_lats, N_lons, 3) XYZ grid.
343+
x = cos_lat[:, numpy.newaxis] * cos_lon[numpy.newaxis, :]
344+
y = cos_lat[:, numpy.newaxis] * sin_lon[numpy.newaxis, :]
345+
z = numpy.repeat(sin_lat[:, numpy.newaxis], len(lons), axis=1)
346+
query_xyz = numpy.stack([x, y, z], axis=-1).reshape(-1, 3)
347+
348+
# Use distance_upper_bound to prune search for points far from any data.
349+
dd, ii = tree.query(query_xyz, k=1, distance_upper_bound=max_chord)
350+
351+
# Flattened block view for assignment.
352+
flat_ii = ii.flatten()
353+
flat_dd = dd.flatten()
354+
valid = flat_dd <= max_chord
355+
356+
# Fancy indexing into index.source_indices.
357+
block_indices = numpy.full(
358+
(len(flat_ii), 2), GLT_FILL_VALUE, dtype=numpy.int64
359+
)
360+
block_indices[valid] = index.source_indices[flat_ii[valid]]
361+
362+
glt_full[col_start:col_end, :, :] = block_indices.reshape(
363+
(col_end - col_start, num_rows, 2)
405364
)
406-
for row_idx, (near_dist, near_idx) in enumerate(zip(dd, ii)):
407-
if near_dist[0] > max_distance:
408-
continue
409-
elif near_idx[0] >= len(index.points):
410-
raise ProjectionError('Bad nearest index {}'.format(near_idx[0]))
411-
else:
412-
orig_point = index.points[near_idx[0]]
413-
orig_ij = index.point_index.get(orig_point)
414-
if orig_ij is None:
415-
raise ProjectionError('Bad nearest point {}'.format(orig_point))
416-
glt_i[col_idx][row_idx] = orig_ij[0]
417-
glt_j[col_idx][row_idx] = orig_ij[1]
418-
419-
# pylint: enable=cell-var-from-loop
420365

421366
with concurrent.futures.ThreadPoolExecutor(
422367
max_workers=num_threads
423368
) as executor:
424-
for col_idx in range(0, num_cols):
425-
executor.submit(_fill_glt_col, col_idx)
369+
for col_start in range(0, num_cols, block_size):
370+
executor.submit(_fill_glt_block, col_start)
426371

427-
glt = numpy.stack((glt_i, glt_j), axis=-1, dtype=numpy.int64)
428-
tables.append(cls(bbox, scale_lat, scale_lon, glt))
372+
tables.append(cls(bbox, scale_lat, scale_lon, glt_full))
429373

430374
return tables
431375

0 commit comments

Comments
 (0)