 from typing import Any, Dict, List, Optional, Union

 import netCDF4 as nc
+import psutil
 import xarray as xr
 from cftime import num2date
+from dask.distributed import get_client

 from access_moppy.utilities import (
     FrequencyMismatchError,
@@ -275,6 +277,89 @@ def write(self):
                 f"Missing required CMIP6 global attributes for filename: {missing}"
             )

+        # ========== Memory Check ==========
+        # Estimate the in-memory size of the dataset and compare it against
+        # available memory to prevent out-of-memory errors during the write.
+
+        def estimate_data_size(ds):
+            total_size = 0
+            for var in ds.variables:
+                vdat = ds[var]
+                # Start with the size of a single element (e.g., 4 bytes for float32)
+                var_size = vdat.dtype.itemsize
+                # Multiply by the length of each dimension to get total bytes
+                for dim in vdat.dims:
+                    var_size *= ds.sizes[dim]
+                total_size += var_size
+            # Apply a 1.5x overhead factor for a conservative estimate
+            return int(total_size * 1.5)
+
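+        # Illustrative sizing (hypothetical shapes): a float32 variable with dims
+        # (time=1200, lat=180, lon=360) contributes 4 * 1200 * 180 * 360 bytes
+        # ~= 311 MB, or ~467 MB after the 1.5x overhead factor.
+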
+        # Calculate the estimated data size for this dataset
+        data_size = estimate_data_size(self.ds)
+
+        # Query available system memory via psutil
+        available_memory = psutil.virtual_memory().available
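+        # Note: psutil reports byte counts; "available" is the memory that can
+        # be given to processes without pushing the system into swap.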
+
+        # ========== Dask Client Detection ==========
+        # Check whether a Dask distributed client exists, since that changes how
+        # memory is managed: a cluster enforces its own per-worker limits,
+        # separate from system memory.
+
+        client = None
+        worker_memory = None  # Smallest memory limit among the workers
+        total_cluster_memory = None  # Sum of all workers' memory limits
+
+        try:
+            # Attempt to get an existing Dask client
+            client = get_client()
+
+            # Retrieve information about all workers in the cluster
+            worker_info = client.scheduler_info()["workers"]
+
+            if worker_info:
+                # Take the minimum memory_limit across all workers
+                worker_memory = min(w["memory_limit"] for w in worker_info.values())
+
+                # Sum all workers' memory for total cluster capacity
+                total_cluster_memory = sum(
+                    w["memory_limit"] for w in worker_info.values()
+                )
+
+        except ValueError:
+            # get_client() raises ValueError when no client exists;
+            # fall back to local/system memory for writing
+            pass
+
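+        # For reference, scheduler_info()["workers"] maps each worker address to
+        # a metadata dict, e.g. {"tcp://...": {"memory_limit": 4294967296,
+        # "nthreads": 2, ...}} (illustrative values; fields vary by dask version).
+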
+        # ========== Memory Validation Logic ==========
+        # Decision tree based on data size versus the memory actually available:
+
+        if client is not None:
+            # A Dask client exists - check against cluster memory limits
+            if worker_memory is not None and data_size > worker_memory:
+                # Data exceeds single-worker capacity; fall back to a local write
+                print(
+                    f"Warning: Data size ({data_size / 1024**3:.2f} GB) exceeds single "
+                    f"worker memory ({worker_memory / 1024**3:.2f} GB); total cluster "
+                    f"memory is {total_cluster_memory / 1024**3:.2f} GB."
+                )
+                print("Closing Dask client to use local memory for writing...")
+                client.close()
+                client = None
+
+            # If data_size <= worker_memory, no action is needed; proceed with the write
+
+        if data_size > available_memory:
+            # Data exceeds available system memory even for a local write
+            raise MemoryError(
+                f"Data size ({data_size / 1024**3:.2f} GB) exceeds available system memory "
+                f"({available_memory / 1024**3:.2f} GB). "
+                f"Consider using write_parallel() for chunked writing."
+            )
+
+        # Log the memory status for user awareness
+        print(
+            f"Data size: {data_size / 1024**3:.2f} GB, "
+            f"Available memory: {available_memory / 1024**3:.2f} GB"
+        )
+
         time_var = self.ds[self.cmor_name].coords["time"]
         units = time_var.attrs["units"]
         calendar = time_var.attrs.get("calendar", "standard").lower()
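
A minimal sketch for exercising the fallback path (hypothetical setup; assumes dask.distributed is installed and the writer class is importable):

    from dask.distributed import Client, LocalCluster

    # One worker capped at 1 GiB, so any dataset whose estimated size exceeds
    # that limit should trigger the warning and the local-write fallback.
    cluster = LocalCluster(n_workers=1, memory_limit="1GiB")
    client = Client(cluster)
    # ... construct the writer and call .write(): a small dataset passes the
    # check silently, while one above ~1 GiB closes the client first.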