# Reconstructed from a newline-collapsed unified diff. Only the two
# definitions that the patch adds in full are reproduced here:
#   * ``Rsut``         -> esmvalcore/cmor/_fixes/native6/era5.py
#   * ``_group_cubes`` -> esmvalcore/cmor/fix.py
# Both rely on names provided by their own modules (``Fix``, ``Rsdt``,
# ``iris``, ``CubeList``, ``defaultdict``, ``_get_start_end_date``).


class Rsut(Fix):
    """Fixes for rsut."""

    # rsut is built from more than one variable and those variables come from
    # more than one file, so fix_metadata needs the cubes grouped by date
    # rather than by file
    GROUP_CUBES_BY_DATE = True

    def fix_metadata(self, cubes):
        """Derive the rsut cube from the cubes of one group.

        The result is computed as

        rsut = rsdt - rsnt

        where

        rsut = TOA Outgoing Shortwave Radiation
        rsdt = TOA Incoming Shortwave Radiation
        rsnt = TOA Net Incoming Shortwave Radiation

        Parameters
        ----------
        cubes:
            Cube list that must contain exactly one cube with long name
            "TOA incident solar radiation" and exactly one with long name
            "Mean top net short-wave radiation flux" (``extract_cube`` is
            used for both look-ups).

        Returns
        -------
        iris.cube.CubeList
            A list holding the single derived cube.

        """
        incoming = cubes.extract_cube(
            iris.NameConstraint(long_name="TOA incident solar radiation"),
        )
        net_incoming = cubes.extract_cube(
            iris.NameConstraint(
                long_name="Mean top net short-wave radiation flux",
            ),
        )

        # Run the rsdt fix on the incoming-radiation cube first, then convert
        # the result to the units defined for this variable
        outgoing = Rsdt(None).fix_metadata([incoming])[0]
        outgoing.convert_units(self.vardef.units)

        # The net cube is subtracted as-is.
        # NOTE(review): this assumes it already uses the target units and
        # matches the shape of the incoming cube -- confirm.
        difference = outgoing.core_data() - net_incoming.core_data()
        outgoing.data = difference
        outgoing.attributes["positive"] = "up"

        return iris.cube.CubeList([outgoing])


def _group_cubes(fixes: Iterable[Fix], cubes: CubeList) -> dict[Any, CubeList]:
    """Split cubes into the groups that fix_metadata handles one by one.

    Parameters
    ----------
    fixes:
        Fixes that will be applied; if any of them sets
        ``GROUP_CUBES_BY_DATE`` to a true value, cubes are grouped by date,
        otherwise by file.
    cubes:
        Cubes to be grouped.

    Returns
    -------
    dict[Any, CubeList]
        Mapping from group key to the cubes sharing that key. When grouping
        by date, the key is the return value of ``_get_start_end_date`` for
        the cube's ``source_file`` attribute (``None`` if the attribute is
        missing). When grouping by file, the key is the ``source_file``
        attribute itself (``""`` if the attribute is missing).

    """
    if any(fix.GROUP_CUBES_BY_DATE for fix in fixes):

        def group_key(cube):
            """Return the date-based key for a cube."""
            if "source_file" not in cube.attributes:
                return None
            return _get_start_end_date(cube.attributes["source_file"])

    else:

        def group_key(cube):
            """Return the file-based key for a cube."""
            return cube.attributes.get("source_file", "")

    groups: dict[Any, CubeList] = defaultdict(CubeList)
    for cube in cubes:
        groups[group_key(cube)].append(cube)

    return groups