Skip to content

Commit fe4443a

Browse files
committed
Improve peak memory usage in CF grid conventions
1 parent 807a3ee commit fe4443a

File tree

4 files changed

+53
-50
lines changed

4 files changed

+53
-50
lines changed

src/emsarray/conventions/grid.py

Lines changed: 42 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -413,39 +413,32 @@ def _make_polygons(self) -> numpy.ndarray:
413413
lon_bounds = self.topology.longitude_bounds.values
414414
lat_bounds = self.topology.latitude_bounds.values
415415

416-
# Make a bounds array as if this dataset had 2D coordinates.
417-
# 1D bounds are (lon, 2) and (lat, 2).
418-
# 2D bounds are (lat, lon, 4)
419-
# where the 4 points are (j-1, i-1), (j-1, i+1), (j+1, i+1), (j+1, i-1).
420-
# The bounds values are repeated as required, are given a new dimension,
421-
# then repeated along that new dimension.
422-
# They will come out as array with shape (y_size, x_size, 4)
423-
424-
lon_bounds_2d = numpy.stack([
425-
lon_bounds[:, 0],
426-
lon_bounds[:, 1],
427-
lon_bounds[:, 1],
428-
lon_bounds[:, 0],
429-
], axis=-1)
430-
lon_bounds_2d = numpy.broadcast_to(numpy.expand_dims(lon_bounds_2d, 0), (y_size, x_size, 4))
431-
432-
lat_bounds_2d = numpy.stack([
433-
lat_bounds[:, 0],
434-
lat_bounds[:, 0],
435-
lat_bounds[:, 1],
436-
lat_bounds[:, 1],
437-
], axis=-1)
438-
lat_bounds_2d = numpy.broadcast_to(numpy.expand_dims(lat_bounds_2d, 0), (x_size, y_size, 4))
439-
lat_bounds_2d = numpy.transpose(lat_bounds_2d, (1, 0, 2))
440-
441-
assert lon_bounds_2d.shape == lat_bounds_2d.shape == (y_size, x_size, 4)
442-
443-
# points is a (topology.size, 4, 2) array of the corners of each cell
444-
points = numpy.stack([lon_bounds_2d, lat_bounds_2d], axis=-1).reshape((-1, 4, 2))
445-
446-
polygons = utils.make_polygons_with_holes(points)
447-
448-
return polygons
416+
# Create the polygons batched by row.
417+
# The point array is copied by shapely before being used,
418+
# so this can accidentally use a whole bunch of memory for large datasets.
419+
# Creating them one by one is very slow but very memory efficient.
420+
# Creating the polygons in one batch is faster but uses up a huge amount of memory.
421+
# Batching them row by row is a decent compromise.
422+
out = numpy.full(shape=y_size * x_size, dtype=object, fill_value=None)
423+
424+
# By preallocating this array, we can copy data into it to save on a number of allocations.
425+
chunk_points = numpy.empty(shape=(x_size, 4, 2), dtype=lon_bounds.dtype)
426+
# By chunking by row, the longitude bounds never change between loops
427+
chunk_points[:, 0, 0] = lon_bounds[:, 0]
428+
chunk_points[:, 1, 0] = lon_bounds[:, 1]
429+
chunk_points[:, 2, 0] = lon_bounds[:, 1]
430+
chunk_points[:, 3, 0] = lon_bounds[:, 0]
431+
432+
for row in range(y_size):
433+
chunk_points[:, 0, 1] = lat_bounds[row, 0]
434+
chunk_points[:, 1, 1] = lat_bounds[row, 0]
435+
chunk_points[:, 2, 1] = lat_bounds[row, 1]
436+
chunk_points[:, 3, 1] = lat_bounds[row, 1]
437+
438+
row_slice = slice(row * x_size, (row + 1) * x_size)
439+
utils.make_polygons_with_holes(chunk_points, out=out[row_slice])
440+
441+
return out
449442

450443
@cached_property
451444
def face_centres(self) -> numpy.ndarray:
@@ -597,13 +590,22 @@ def check_dataset(cls, dataset: xarray.Dataset) -> int | None:
597590

598591
def _make_polygons(self) -> numpy.ndarray:
599592
# Construct polygons from the bounds of the cells
600-
lon_bounds = self.topology.longitude_bounds.values
601-
lat_bounds = self.topology.latitude_bounds.values
602-
603-
# points is a (topology.size, 4, 2) array of the corners of each cell
604-
points = numpy.stack([lon_bounds, lat_bounds], axis=-1).reshape((-1, 4, 2))
605-
606-
return utils.make_polygons_with_holes(points)
593+
j_size, i_size = self.topology.shape
594+
lon_bounds = self.topology.longitude_bounds
595+
lat_bounds = self.topology.latitude_bounds
596+
597+
assert lon_bounds.shape == (j_size, i_size, 4)
598+
assert lat_bounds.shape == (j_size, i_size, 4)
599+
600+
chunk_points = numpy.empty(shape=(i_size, 4, 2), dtype=lon_bounds.dtype)
601+
out = numpy.full(shape=j_size * i_size, dtype=object, fill_value=None)
602+
for j in range(j_size):
603+
chunk_points[:, :, 0] = lon_bounds[j, :, :]
604+
chunk_points[:, :, 1] = lat_bounds[j, :, :]
605+
chunk_slice = slice(j * i_size, (j + 1) * i_size)
606+
utils.make_polygons_with_holes(chunk_points, out=out[chunk_slice])
607+
608+
return out
607609

608610
@cached_property
609611
def face_centres(self) -> numpy.ndarray:

tests/conventions/test_cfgrid1d.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
2-
import pathlib
32
import logging
3+
import pathlib
44

55
import numpy
66
import pandas
@@ -16,7 +16,9 @@
1616
CFGrid1D, CFGrid1DTopology, CFGridKind, CFGridTopology
1717
)
1818
from emsarray.operations import geometry
19-
from tests.utils import assert_property_not_cached, box, mask_from_strings, track_peak_memory_usage
19+
from tests.utils import (
20+
assert_property_not_cached, box, mask_from_strings, track_peak_memory_usage
21+
)
2022

2123
logger = logging.getLogger(__name__)
2224

@@ -512,8 +514,8 @@ def test_make_polygon_memory_usage() -> None:
512514
with track_peak_memory_usage() as tracker:
513515
assert len(dataset.ems.polygons) == width * height
514516

515-
logger.info(f"current memory usage: %d, peak memory usage: %d", tracker.current, tracker.peak)
517+
logger.info("current memory usage: %d, peak memory usage: %d", tracker.current, tracker.peak)
516518

517-
target = 537_000_000
519+
target = 135_000_000
518520
assert tracker.peak < target, "Peak memory allocation is too large"
519521
assert tracker.peak > target * 0.9, "Peak memory allocation is suspiciously small - did you improve things?"

tests/conventions/test_cfgrid2d.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
"""
88
import itertools
99
import json
10-
import pathlib
1110
import logging
11+
import pathlib
1212

1313
import numpy
1414
import pandas
@@ -26,7 +26,8 @@
2626
from emsarray.operations import geometry
2727
from tests.utils import (
2828
AxisAlignedShocGrid, DiagonalShocGrid, ShocGridGenerator,
29-
ShocLayerGenerator, assert_property_not_cached, plot_geometry, track_peak_memory_usage,
29+
ShocLayerGenerator, assert_property_not_cached, plot_geometry,
30+
track_peak_memory_usage
3031
)
3132

3233
logger = logging.getLogger(__name__)
@@ -509,8 +510,8 @@ def test_make_polygon_memory_usage() -> None:
509510
with track_peak_memory_usage() as tracker:
510511
assert len(dataset.ems.polygons) == j_size * i_size
511512

512-
logger.info(f"current memory usage: %d, peak memory usage: %d", tracker.current, tracker.peak)
513+
logger.info("current memory usage: %d, peak memory usage: %d", tracker.current, tracker.peak)
513514

514-
target = 665_000_000
515+
target = 300_000_000
515516
assert tracker.peak < target, "Peak memory allocation is too large"
516517
assert tracker.peak > target * 0.9, "Peak memory allocation is suspiciously small - did you improve things?"

tests/utils.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import abc
22
import contextlib
3-
import dataclasses
43
import importlib.metadata
54
import itertools
65
import tracemalloc
@@ -467,7 +466,6 @@ def plot_geometry(
467466
figure.savefig(out)
468467

469468

470-
471469
class TracemallocTracker:
472470
_finished = False
473471
_usage = None

0 commit comments

Comments
 (0)