
Commit 6713c53

Merge PR #402 (Optimize get_diff_of_diffs and add_lumped_species_to_dataset)
This merge brings PR #402 (Optimize get_diff_of_diffs and add_lumped_species_to_dataset, by @yantosca) into the GCPy 1.7.0 development stream.

PR #402 does the following:

1. In function `make_benchmark_conc_plots`:
   - Optimizes the algorithm that computes the list of variables for each benchmark category
   - Passes truncated datasets to `compare_single_level` and `compare_zonal_mean`
2. In routine `add_lumped_species_to_dataset`:
   - Vectorizes the summing of species
   - Merges new species into the Dataset in a single operation
3. In routine `get_diff_of_diffs`:
   - Optimizes the algorithm to compute diff-of-diffs
   - Aligns cubed-sphere grids before computing diff-of-diffs
4. In routines `create_regridders`, `compare_single_level`, and `compare_zonal_mean`:
   - Deletes regridder objects once they are no longer needed
   - Manually calls gc.collect() to force garbage collection

Signed-off-by: Bob Yantosca <yantosca@seas.harvard.edu>
2 parents eaeae9b + d930cdc commit 6713c53
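The varlist optimization described in item 1 above can be sketched in plain Python: precompute set membership once, then filter the candidate variable names. The category dictionary and variable names below are illustrative stand-ins, not GCPy's actual benchmark data.

```python
# Hypothetical data mimicking the PR's refds/devds variable lists
coll_prefix = "SpeciesConcVV_"
catdict = {"Oxidants": {"Ox": ["O3", "NO2"]}}
ref_vars = {"SpeciesConcVV_O3", "SpeciesConcVV_NO2"}
dev_vars = {"SpeciesConcVV_O3"}

candidates = [
    coll_prefix + spc
    for subcat in catdict["Oxidants"]
    for spc in catdict["Oxidants"][subcat]
]

# Set lookups are O(1), so the filter scales linearly with candidates
varlist = [v for v in candidates if v in ref_vars and v in dev_vars]
warninglist = [v for v in candidates if v not in ref_vars or v not in dev_vars]

print(varlist)       # ['SpeciesConcVV_O3']
print(warninglist)   # ['SpeciesConcVV_NO2']
```

The point of the change is that membership tests against a `set` avoid rescanning the Dataset's variable list for every species.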

File tree

8 files changed (+200, -131 lines)


CHANGELOG.md

Lines changed: 5 additions & 1 deletion
@@ -19,6 +19,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 - Added functions `get_molwt_from_metadata` and `read_species_metadata` to `gcpy/util.py`
 - Added function `get_species_database_files` to `gcpy/benchmark/modules/benchmark_utils.py`
 - Added constant `SPECIES_DATABASE` to `gcpy/benchmark/modules/benchmark_utils.py`
+- Added manual garbage collection in `create_regridders`, `compare_single_level`, and `compare_zonal_mean` functions.
+- Added helpful tips to the `gcpy/benchmark/benchmark.slurm.sh` script

 ### Changed
 - Modified criteria for terminating read of log files in `benchmark_scrape_gcclassic_timers.py` to avoid being spoofed by output that is attached by Intel VTune
@@ -44,6 +46,9 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 - Replaced `get_species_database_dir` with `get_species_database_files` in `gcpy/benchmark/modules/benchmark_funcs.py`
 - Updated `gcpy/benchmark/modules/benchmark_scrape_gchp_timers.py` to look for GCHP timers in `allPEs.log` if not found in the log file
 - Updated routine `make_benchmark_aerosol_tables` to include all dust species in the aerosol burdens table
+- Optimized function `get_diff_of_diffs` (in `gcpy/util.py`) for performance
+- Optimized function `add_lumped_species_to_dataset` (in `gcpy/benchmark/modules/benchmark_utils.py`) for performance
+- Optimized the algorithm to generate `varlist` in `make_benchmark_conc_plots`. Also truncated datasets to only contain variables in `varlist`.

 ### Fixed
 - Fixed grid area calculation scripts of `grid_area` in `gcpy/gcpy/cstools.py`
@@ -59,7 +64,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 - Removed `.load()` statements from xarray Datasets to improve performance
 - Removed `paths:spcdb_dir` YAML tag in benchmark configuration files
 - Removed `st_Ox` from `benchmark_categories.yml`; this species is no longer used in TransportTracers simulations
-- Removed special data handling for files generated with MAPL versions prior to 1.0.0 in function `get_diff_of_diffs` (located in `gcpy/util.py`)

 ## [1.6.2] - 2025-06-12
 ### Added
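The manual garbage collection noted in the changelog above follows a common Python pattern: drop the reference to a large temporary object, then force a collection pass. This is a generic sketch of that pattern; the helper name `process_with_cleanup` is hypothetical and not GCPy code.

```python
import gc

def process_with_cleanup(n):
    """Build a large temporary object, use it, then release it eagerly."""
    big = [bytearray(1024) for _ in range(n)]   # stand-in for a regridder
    result = len(big)
    # Drop the reference and force a collection pass so memory is
    # returned promptly instead of waiting for the next automatic GC.
    del big
    unreachable = gc.collect()
    return result, unreachable

result, unreachable = process_with_cleanup(100)
```

`gc.collect()` returns the number of unreachable objects found; calling it manually mainly helps when large cyclic structures would otherwise linger between automatic collection passes.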

gcpy/benchmark/benchmark_slurm.sh

Lines changed: 14 additions & 4 deletions
@@ -2,9 +2,9 @@
 #SBATCH -c 8
 #SBATCH -N 1
-#SBATCH -t 0-4:00
-#SBATCH -p seas_compute,shared
-#SBATCH --mem=100000
+#SBATCH -t 0-6:00
+#SBATCH -p sapphire,huce_cascade,seas_compute,shared
+#SBATCH --mem=180000
 #SBATCH --mail-type=END

 #============================================================================
@@ -13,7 +13,17 @@
 #
 # You can modify the SLURM parameters above for your setup.
 #
-# Tip: Using less cores can reduce the amount of memory required.
+# Tips:
+# -----
+# (1) Use fewer cores to reduce the memory footprint.  This may prevent
+#     your job from running out of memory.  Python under Linux seems
+#     to have an issue where not all memory is released back to the OS.
+#
+# (2) We recommend that you generate only one benchmark comparison
+#     (GCC vs GCC, GCHP vs GCC, GCHP vs GCHP, or diff-of-diffs)
+#     at a time.  Otherwise your job will probably run out of memory.
+#
+# (3) For diff-of-diffs plots, we recommend using 6 cores.
 #============================================================================

 # Apply all bash initialization settings
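The memory tips in the script above come down to simple arithmetic: under a fixed `--mem` request, fewer cores means more memory available per worker process. A small Python illustration (the per-core split is illustrative; only the 180000 MB figure comes from the SBATCH header):

```python
# Fixed SLURM memory request, in MB, as in the script header (--mem=180000)
mem_mb = 180_000

# Memory available per worker for two possible core counts
per_core = {cores: mem_mb // cores for cores in (8, 6)}

print(per_core)  # {8: 22500, 6: 30000}
```

With 6 cores each worker gets 30000 MB instead of 22500 MB, which is why fewer cores can keep a memory-heavy diff-of-diffs job from being killed.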

gcpy/benchmark/modules/benchmark_funcs.py

Lines changed: 47 additions & 31 deletions
@@ -1277,7 +1277,7 @@ def make_benchmark_conc_plots(
     devds = add_lumped_species_to_dataset(devds)

     if diff_of_diffs:
-        print("-->Adding lumped species to dev datasets")
+        print("-->Adding lumped species to second ref and dev datasets")
         second_refds = add_lumped_species_to_dataset(second_refds)
         second_devds = add_lumped_species_to_dataset(second_devds)
@@ -1340,19 +1340,34 @@ def createplots(filecat):
         if not os.path.isdir(catdir):
             os.mkdir(catdir)

-        varlist = []
-        warninglist = []
-        for subcat in catdict[filecat]:
-            for spc in catdict[filecat][subcat]:
-                varname = coll_prefix + spc
-                if varname not in refds.data_vars or \
-                   varname not in devds.data_vars:
-                    warninglist.append(varname)
-                    continue
-                varlist.append(varname)
-        if warninglist:
-            msg = f"\n\nWarning: variables in {filecat} category not in dataset: {warninglist}"
-            print(msg)
+        # Get the list of variables in both Ref and Dev for each category
+        # (this is computationally efficient)
+        ref_vars = set(refds.data_vars)
+        dev_vars = set(devds.data_vars)
+        candidates = [
+            coll_prefix + spc
+            for subcat in catdict[filecat]
+            for spc in catdict[filecat][subcat]
+        ]
+        varlist = [
+            var for var in candidates
+            if var in ref_vars and var in dev_vars
+        ]
+        warninglist = [
+            var for var in candidates
+            if var not in ref_vars or var not in dev_vars
+        ]
+
+        # Create new datasets containing only the variables for a
+        # given category, as this will optimize performance.
+        refds_cat = refds[varlist]
+        devds_cat = devds[varlist]
+        second_refds_cat = None
+        if second_refds is not None:
+            second_refds_cat = second_refds[varlist]
+        second_devds_cat = None
+        if second_devds is not None:
+            second_devds_cat = second_devds[varlist]

         # -----------------------
         # Surface plots
@@ -1373,9 +1388,9 @@ def createplots(filecat):

         diff_sfc = []
         compare_single_level(
-            refds,
+            refds_cat,
             refstr,
-            devds,
+            devds_cat,
             devstr,
             varlist=varlist,
             ilev=0,
@@ -1390,8 +1405,8 @@ def createplots(filecat):
             sigdiff_list=diff_sfc,
             weightsdir=weightsdir,
             convert_to_ugm3=convert_to_ugm3,
-            second_ref=second_refds,
-            second_dev=second_devds,
+            second_ref=second_refds_cat,
+            second_dev=second_devds_cat,
             n_job=n_job,
             spcdb_files=spcdb_files,
         )
@@ -1412,7 +1427,8 @@ def createplots(filecat):

         if subdst is not None:
             pdfname = os.path.join(
-                catdir, f"{filecat}_500hPa_{subdst}.pdf"
+                catdir,
+                f"{filecat}_500hPa_{subdst}.pdf"
             )
         else:
             pdfname = os.path.join(
@@ -1422,9 +1438,9 @@ def createplots(filecat):

         diff_500 = []
         compare_single_level(
-            refds,
+            refds_cat,
             refstr,
-            devds,
+            devds_cat,
             devstr,
             varlist=varlist,
             ilev=22,
@@ -1439,8 +1455,8 @@ def createplots(filecat):
             sigdiff_list=diff_500,
             weightsdir=weightsdir,
             convert_to_ugm3=convert_to_ugm3,
-            second_ref=second_refds,
-            second_dev=second_devds,
+            second_ref=second_refds_cat,
+            second_dev=second_devds_cat,
             n_job=n_job,
             spcdb_files=spcdb_files
         )
@@ -1473,9 +1489,9 @@ def createplots(filecat):

         diff_zm = []
         compare_zonal_mean(
-            refds,
+            refds_cat,
             refstr,
-            devds,
+            devds_cat,
             devstr,
             varlist=varlist,
             refmet=refmetds,
@@ -1488,8 +1504,8 @@ def createplots(filecat):
             sigdiff_list=diff_zm,
             weightsdir=weightsdir,
             convert_to_ugm3=convert_to_ugm3,
-            second_ref=second_refds,
-            second_dev=second_devds,
+            second_ref=second_refds_cat,
+            second_dev=second_devds_cat,
             n_job=n_job,
             spcdb_files=spcdb_files
         )
@@ -1518,9 +1534,9 @@ def createplots(filecat):
         )

         compare_zonal_mean(
-            refds,
+            refds_cat,
             refstr,
-            devds,
+            devds_cat,
             devstr,
             varlist=varlist,
             refmet=refmetds,
@@ -1534,8 +1550,8 @@ def createplots(filecat):
             normalize_by_area=normalize_by_area,
             convert_to_ugm3=convert_to_ugm3,
             weightsdir=weightsdir,
-            second_ref=second_refds,
-            second_dev=second_devds,
+            second_ref=second_refds_cat,
+            second_dev=second_devds_cat,
             n_job=n_job,
             spcdb_files=spcdb_files
         )
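The dataset-truncation idea in the diff above (pass the plotting routines only the variables they will actually plot) can be illustrated with a plain dictionary standing in for an xarray Dataset; the variable names and values here are invented for illustration.

```python
# A dict mimicking a Dataset's mapping of variable name -> data
refds = {
    "SpeciesConcVV_O3": [1.0, 2.0],
    "SpeciesConcVV_CO": [3.0, 4.0],
    "SpeciesConcVV_NO": [5.0, 6.0],
}
varlist = ["SpeciesConcVV_O3", "SpeciesConcVV_CO"]

# Analogous to refds_cat = refds[varlist] on an xarray Dataset:
# keep only the variables needed for this plot category
refds_cat = {var: refds[var] for var in varlist}

print(sorted(refds_cat))  # ['SpeciesConcVV_CO', 'SpeciesConcVV_O3']
```

Handing downstream functions the smaller `refds_cat` means they never touch (or hold references to) variables outside the current category, which is where the memory and speed win comes from.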

gcpy/benchmark/modules/benchmark_utils.py

Lines changed: 61 additions & 52 deletions
@@ -366,7 +366,8 @@ def add_lumped_species_to_dataset(
     as a dictionary or a path to a yaml file. If neither is passed then
     the lumped species yaml file stored in gcpy is used. This file is
     customized for use with benchmark simulation SpeciesConc diagnostic
-    collection output.
+    collection output. The algorithm has been optimized to improve
+    performance.

     Args
         dset : xr.Dataset : Data prior to adding lumped species
@@ -378,11 +379,22 @@ def add_lumped_species_to_dataset(

     Returns
         dset : xr.Dataset : Original species plus added lumped species
+
+    Remarks
+    -------
+    Key Improvements:
+    1. Vectorized summation: uses sum(to_sum) instead of incremental +=
+    2. Lazy evaluation: operations remain lazy until actual computation
+    3. Single merge: uses .assign() instead of merging many DataArrays
+    4. Cleaner logic: more Pythonic dictionary iteration
+
+    Performance Impact:
+    Original:  O(n_lumped × n_constituents) individual array operations
+    Optimized: O(n_lumped) vectorized operations
     """

-    # Default is to add all benchmark lumped species.
-    # Can overwrite by passing a dictionary
-    # or a yaml file path containing one
+    # Default is to add all benchmark lumped species. Can overwrite
+    # by passing a dictionary or a yaml file path containing one.
     assert not (
         lspc_dict is not None and lspc_yaml != ""
     ), "Cannot pass both lspc_dict and lspc_yaml. Choose one only."
@@ -394,67 +406,64 @@ def add_lumped_species_to_dataset(
     # Make sure attributes are transferred when copying dataset / dataarrays
     with xr.set_options(keep_attrs=True):

-        # Get a dummy DataArray to use for initialization
-        dummy_darr = None
-        for var in dset.data_vars:
-            if prefix in var or prefix.replace("VV", "") in var:
-                dummy_darr = dset[var]
-                dummy_type = dummy_darr.dtype
-                dummy_shape = dummy_darr.shape
-                break
-        if dummy_darr is None:
-            msg = "Invalid prefix: " + prefix
-            raise ValueError(msg)
-
-        # Create a list with a copy of the dummy DataArray object
-        n_lumped_spc = len(lspc_dict)
-        lumped_spc = [None] * n_lumped_spc
-        for var, spcname in enumerate(lspc_dict):
-            lumped_spc[var] = dummy_darr.copy(deep=False)
-            lumped_spc[var].name = prefix + spcname
-            lumped_spc[var].values = np.full(dummy_shape, 0.0, dtype=dummy_type)
-
-        # Loop over lumped species list
-        for var, lspc in enumerate(lumped_spc):
-
-            # Search key for lspc_dict is lspc.name minus the prefix
-            cidx = lspc.name.find("_")
-            key = lspc.name[cidx+1:]
+        # Dictionary to store new lumped species
+        new_vars = {}
+
+        # Loop over lumped species
+        for lspc_name, constituents in lspc_dict.items():
+            full_name = prefix + lspc_name

             # Check if overlap with existing species
-            if lspc.name in dset.data_vars and overwrite:
-                dset.drop(lspc.name)
-            else:
-                assert(lspc.name not in dset.data_vars), \
-                    f"{lspc.name} already in dataset. To overwrite pass overwrite=True."
+            if full_name in dset.data_vars:
+                if overwrite:
+                    if verbose:
+                        print(f"Overwriting existing {full_name}")
+                else:
+                    raise ValueError(
+                        f"{full_name} already in dataset. "
+                        "To overwrite pass overwrite=True."
+                    )

-            # Verbose prints
             if verbose:
-                print(f"Creating {lspc.name}")
+                print(f"Creating {full_name}")

-            # Loop over and sum constituent species values
-            num_spc = 0
-            for _, spcname in enumerate(lspc_dict[key]):
+            # Collect all constituent species that exist
+            to_sum = []
+            for spcname, scale_factor in constituents.items():
                 varname = prefix + spcname
                 if varname not in dset.data_vars:
                     if verbose:
-                        print(f"Warning: {varname} needed for {lspc_dict[key][spcname]} not in dataset")
+                        print(f"Warning: {varname} needed for {scale_factor} not in dataset")
                     continue
+
                 if verbose:
-                    print(f" -> adding {varname} with scale {lspc_dict[key][spcname]}")
-                lspc.values += dset[varname].values * lspc_dict[key][spcname]
-                num_spc += 1
+                    print(f" -> adding {varname} with scale {scale_factor}")
+
+                # Build list of scaled species (lazy operations)
+                to_sum.append(dset[varname] * scale_factor)

-            # Replace values with NaN if no species found in dataset
-            if num_spc == 0:
+            # Vectorized sum of all constituents at once
+            if len(to_sum) > 0:
+                new_vars[full_name] = sum(to_sum)
+            else:
+                # Create NaN array matching first species shape
                 if verbose:
                     print("No constituent species found! Setting to NaN.")
-                lspc.values = np.full(lspc.shape, np.nan)
-
-        # Insert the DataSet into the list of DataArrays
-        # so that we can only do the merge operation once
-        lumped_spc.insert(0, dset)
-        dset = xr.merge(lumped_spc)
+                template_var = next(
+                    (var for key, var in dset.data_vars.items()
+                     if prefix in key or prefix.replace("VV", "") in key),
+                    None
+                )
+                if template_var is not None:
+                    new_vars[full_name] = template_var.copy(deep=False) * np.nan
+
+        # Single merge operation
+        if overwrite:
+            dset = dset.drop_vars(
+                [key for key in new_vars.keys() if key in dset.data_vars],
+                errors='ignore'
+            )
+        dset = dset.assign(new_vars)

     return dset
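The single-merge strategy in the diff above (collect every lumped species in `new_vars`, then attach them with one `dset.assign(new_vars)` call) can be sketched with a plain dict standing in for the xarray Dataset; the species names and values are illustrative.

```python
# A dict mimicking the Dataset's variables
dset = {"SpeciesConcVV_O3": [1.0], "SpeciesConcVV_NO2": [2.0]}

# Accumulate every new lumped species first...
new_vars = {}
new_vars["SpeciesConcVV_Ox"] = [3.0]    # e.g. O3 + NO2
new_vars["SpeciesConcVV_NOy"] = [2.0]

# ...then merge once, analogous to dset = dset.assign(new_vars),
# instead of performing one merge per lumped species
dset = {**dset, **new_vars}

print(len(dset))  # 4 variables after a single merge
```

With real Datasets this matters because each `xr.merge` builds a new Dataset and revalidates coordinates; doing it once instead of once per species removes that repeated overhead.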
