
Commit d901f03

Optimise write (#132)
* solve memory exceeded in write
* update pyproject.toml to solve CI issue
1 parent 207f119 commit d901f03

File tree

2 files changed, +86 -0 lines

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ dependencies = [
     "netCDF4",
     "cftime",
     "dask",
+    "distributed>=2024.0.0",
     "pyyaml",
     "tqdm",
     "requests",

src/access_moppy/base.py

Lines changed: 85 additions & 0 deletions
@@ -4,8 +4,10 @@
 from typing import Any, Dict, List, Optional, Union

 import netCDF4 as nc
+import psutil
 import xarray as xr
 from cftime import num2date
+from dask.distributed import get_client

 from access_moppy.utilities import (
     FrequencyMismatchError,
@@ -275,6 +277,89 @@ def write(self):
                 f"Missing required CMIP6 global attributes for filename: {missing}"
             )

+        # ========== Memory Check ==========
+        # This section estimates the data size and compares it against available memory
+        # to prevent out-of-memory errors during the write operation.
+
+        def estimate_data_size(ds, cmor_name):
+            total_size = 0
+            for var in ds.variables:
+                vdat = ds[var]
+                # Start with the size of a single element (e.g., 4 bytes for float32)
+                var_size = vdat.dtype.itemsize
+                # Multiply by the size of each dimension to get total elements
+                for dim in vdat.dims:
+                    var_size *= ds.sizes[dim]
+                total_size += var_size
+            # Apply 1.5x overhead factor for safe memory estimation
+            return int(total_size * 1.5)
+
+        # Calculate the estimated data size for this dataset
+        data_size = estimate_data_size(self.ds, self.cmor_name)
+
+        # Get system memory information using psutil
+        available_memory = psutil.virtual_memory().available
+
+        # ========== Dask Client Detection ==========
+        # Check if a Dask distributed client exists, as this affects how we handle
+        # memory management. Dask clusters have their own memory limits separate
+        # from system memory.
+
+        client = None
+        worker_memory = None  # Memory limit of a single worker
+        total_cluster_memory = None  # Sum of all workers' memory limits
+
+        try:
+            # Attempt to get an existing Dask client
+            client = get_client()
+
+            # Retrieve information about all workers in the cluster
+            worker_info = client.scheduler_info()["workers"]
+
+            if worker_info:
+                # Get the minimum memory_limit across all workers
+                worker_memory = min(w["memory_limit"] for w in worker_info.values())
+
+                # Sum up all workers' memory for total cluster capacity
+                total_cluster_memory = sum(
+                    w["memory_limit"] for w in worker_info.values()
+                )
+
+        except ValueError:
+            # No Dask client exists - we'll use local/system memory for writing
+            pass
+
+        # ========== Memory Validation Logic ==========
+        # This section implements a decision tree based on data size vs available memory:
+
+        if client is not None:
+            # Dask client exists - check against cluster memory limits
+            if data_size > worker_memory:
+                # WARNING: Data fits in total cluster memory but exceeds single worker capacity
+                print(
+                    f"Warning: Data size ({data_size / 1024**3:.2f} GB) exceeds single worker memory "
+                    f"({worker_memory / 1024**3:.2f} GB) but fits in total cluster memory "
+                    f"({total_cluster_memory / 1024**3:.2f} GB)."
+                )
+                print("Closing Dask client to use local memory for writing...")
+                client.close()
+                client = None
+
+            # If data < worker_memory: No action needed, proceed with write
+
+        if data_size > available_memory:
+            # Data exceeds available system memory
+            raise MemoryError(
+                f"Data size ({data_size / 1024**3:.2f} GB) exceeds available system memory "
+                f"({available_memory / 1024**3:.2f} GB). "
+                f"Consider using write_parallel() for chunked writing."
+            )
+
+        # Log the memory status for user awareness
+        print(
+            f"Data size: {data_size / 1024**3:.2f} GB, Available memory: {available_memory / 1024**3:.2f} GB"
+        )
+
         time_var = self.ds[self.cmor_name].coords["time"]
         units = time_var.attrs["units"]
         calendar = time_var.attrs.get("calendar", "standard").lower()
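
The decision tree above, in short: if a Dask client is found and the estimated size exceeds a single worker's memory_limit, the client is closed and the write falls back to local memory; if the size also exceeds available system memory, a MemoryError points the user at write_parallel(). Below is a standalone sketch of that check against a toy dataset and a deliberately small local cluster; the dataset, sizes, and names are illustrative and not taken from the commit:

    import numpy as np
    import psutil
    import xarray as xr
    from dask.distributed import Client, LocalCluster, get_client

    # Toy dataset; the committed helper walks ds.variables the same way.
    ds = xr.Dataset(
        {"tas": (("time", "lat", "lon"), np.zeros((10, 4, 8), dtype="float32"))}
    )

    # Same arithmetic as estimate_data_size(): itemsize x dim sizes x 1.5 overhead.
    data_size = int(
        sum(
            ds[v].dtype.itemsize * int(np.prod([ds.sizes[d] for d in ds[v].dims]))
            for v in ds.variables
        )
        * 1.5
    )

    cluster = LocalCluster(n_workers=1, memory_limit="1GiB", processes=False)
    client = Client(cluster)

    workers = get_client().scheduler_info()["workers"]
    worker_memory = min(w["memory_limit"] for w in workers.values())

    if data_size > worker_memory:
        print("write() would close the client and use local memory")
    if data_size > psutil.virtual_memory().available:
        print("write() would raise MemoryError and suggest write_parallel()")
    else:
        print(f"fits: ~{data_size} bytes against a {worker_memory}-byte worker limit")

    client.close()
    cluster.close()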
