diff --git a/esmvalcore/cmor/check.py b/esmvalcore/cmor/check.py index a75dcdaab4..de10d74424 100644 --- a/esmvalcore/cmor/check.py +++ b/esmvalcore/cmor/check.py @@ -985,6 +985,12 @@ def cmor_check_data( check_level=check_level, ) cube = checker(cube).check_data() + # Remove the "source_file" attribute that `esmvalcore.preprocessor.load` + # adds for CMOR fix and check function logging purposes. This is a bit + # ugly and it would be nice to stop using the "source_file" attribute and + # pass the data source as an argument to those functions that need it + # instead. + cube.attributes.pop("source_file", None) return cube diff --git a/esmvalcore/cmor/fix.py b/esmvalcore/cmor/fix.py index ab81353cfb..45d281ef51 100644 --- a/esmvalcore/cmor/fix.py +++ b/esmvalcore/cmor/fix.py @@ -8,7 +8,6 @@ from __future__ import annotations import logging -from collections import defaultdict from collections.abc import Sequence from pathlib import Path from typing import TYPE_CHECKING, Optional @@ -137,7 +136,7 @@ def fix_metadata( Returns ------- iris.cube.CubeList - Fixed cubes. + A list containing a single fixed cube. 
""" # Update extra_facets with variable information given as regular arguments @@ -161,27 +160,14 @@ def fix_metadata( session=session, frequency=frequency, ) - fixed_cubes = CubeList() - # Group cubes by input file and apply all fixes to each group element - # (i.e., each file) individually - by_file = defaultdict(list) - for cube in cubes: - by_file[cube.attributes.get("source_file", "")].append(cube) + cubes = CubeList(cubes) + for fix in fixes: + cubes = fix.fix_metadata(cubes) - for cube_list in by_file.values(): - cube_list = CubeList(cube_list) - for fix in fixes: - cube_list = fix.fix_metadata(cube_list) - - # The final fix is always GenericFix, whose fix_metadata method always - # returns a single cube - cube = cube_list[0] - - cube.attributes.pop("source_file", None) - fixed_cubes.append(cube) - - return fixed_cubes + # The final fix is always GenericFix, whose fix_metadata method always + # returns a single cube + return CubeList(cubes[:1]) def fix_data( diff --git a/esmvalcore/dataset.py b/esmvalcore/dataset.py index 6717485e38..33dcdbf9cd 100644 --- a/esmvalcore/dataset.py +++ b/esmvalcore/dataset.py @@ -7,12 +7,15 @@ import re import textwrap import uuid +from collections.abc import Iterable from copy import deepcopy from fnmatch import fnmatchcase from itertools import groupby from pathlib import Path -from typing import Any, Iterator, Sequence, Union +from typing import Any, Iterator, Sequence, TypeVar, Union +import dask +from dask.delayed import Delayed from iris.cube import Cube from esmvalcore import esgf, local @@ -84,6 +87,14 @@ def _ismatch(facet_value: FacetValue, pattern: FacetValue) -> bool: ) +T = TypeVar("T") + + +def _first(elems: Iterable[T]) -> T: + """Return the first element.""" + return next(iter(elems)) + + class Dataset: """Define datasets, find the related files, and load them. 
@@ -693,9 +704,19 @@ def files(self) -> Sequence[File]: def files(self, value): self._files = value - def load(self) -> Cube: + def load(self, compute: bool = True) -> Cube | Delayed: """Load dataset. + Parameters + ---------- + compute: + If :obj:`True`, return the :class:`~iris.cube.Cube` immediately. + If :obj:`False`, return a :class:`~dask.delayed.Delayed` object + that can be used to load the cube by calling its + :meth:`~dask.delayed.Delayed.compute` method. Multiple datasets + can be loaded in parallel by passing a list of such delayeds + to :func:`dask.compute`. + Raises ------ InputFilesNotFound @@ -718,7 +739,7 @@ def load(self) -> Cube: supplementary_cubes.append(supplementary_cube) output_file = _get_output_file(self.facets, self.session.preproc_dir) - cubes = preprocess( + cubes = dask.delayed(preprocess)( [cube], "add_supplementary_variables", input_files=input_files, @@ -727,7 +748,10 @@ def load(self) -> Cube: supplementary_cubes=supplementary_cubes, ) - return cubes[0] + cube = dask.delayed(_first)(cubes) + if compute: + return cube.compute() + return cube def _load(self) -> Cube: """Load self.files into an iris cube and return it.""" @@ -742,7 +766,16 @@ def _load(self) -> Cube: msg = "\n".join(lines) raise InputFilesNotFound(msg) + input_files = [ + file.local_file(self.session["download_dir"]) + if isinstance(file, esgf.ESGFFile) + else file + for file in self.files + ] output_file = _get_output_file(self.facets, self.session.preproc_dir) + debug = self.session["save_intermediary_cubes"] + + # Load all input files and concatenate them. 
fix_dir_prefix = Path( self.session._fixed_file_dir, self._get_joined_summary_facets("_", join_lists=True) + "_", @@ -765,6 +798,51 @@ def _load(self) -> Cube: **self.facets, } settings["concatenate"] = {"check_level": self.session["check_level"]} + + result = [] + for input_file in input_files: + files = dask.delayed(preprocess)( + [input_file], + "fix_file", + input_files=[input_file], + output_file=output_file, + debug=debug, + **settings["fix_file"], + ) + # Multiple cubes may be present in a file. + cubes = dask.delayed(preprocess)( + files, + "load", + input_files=[input_file], + output_file=output_file, + debug=debug, + **settings["load"], + ) + # Combine the cubes into a single cube per file. + cubes = dask.delayed(preprocess, pure=False)( + cubes, + "fix_metadata", + input_files=[input_file], + output_file=output_file, + debug=debug, + **settings["fix_metadata"], + ) + cube = dask.delayed(_first)(cubes) + result.append(cube) + + # Concatenate the cubes from all files. + result = dask.delayed(preprocess, pure=False)( + result, + "concatenate", + input_files=input_files, + output_file=output_file, + debug=debug, + **settings["concatenate"], + ) + + # At this point `result` is a list containing a single cube. Apply the + # remaining preprocessor functions to this cube. 
+ settings.clear() settings["cmor_check_metadata"] = { "check_level": self.session["check_level"], "cmor_table": self.facets["project"], @@ -777,6 +855,7 @@ def _load(self) -> Cube: "timerange": self.facets["timerange"], } settings["fix_data"] = { + "pure": False, "session": self.session, **self.facets, } @@ -787,24 +866,18 @@ def _load(self) -> Cube: "frequency": self.facets["frequency"], "short_name": self.facets["short_name"], } - - result = [ - file.local_file(self.session["download_dir"]) - if isinstance(file, esgf.ESGFFile) - else file - for file in self.files - ] for step, kwargs in settings.items(): - result = preprocess( + pure = kwargs.pop("pure", True) + result = dask.delayed(preprocess, pure=pure)( result, step, - input_files=self.files, + input_files=input_files, output_file=output_file, - debug=self.session["save_intermediary_cubes"], + debug=debug, **kwargs, ) - cube = result[0] + cube = dask.delayed(_first)(result) return cube def from_ranges(self) -> list["Dataset"]: diff --git a/tests/integration/dataset/test_dataset.py b/tests/integration/dataset/test_dataset.py index cdc8310ea0..b95222506e 100644 --- a/tests/integration/dataset/test_dataset.py +++ b/tests/integration/dataset/test_dataset.py @@ -3,6 +3,7 @@ import iris.coords import iris.cube import pytest +from dask.delayed import Delayed from esmvalcore.config import CFG from esmvalcore.dataset import Dataset @@ -55,7 +56,8 @@ def example_data(tmp_path, monkeypatch): monkeypatch.setitem(CFG, "output_dir", tmp_path / "output_dir") -def test_load(example_data): +@pytest.mark.parametrize("lazy", [True, False]) +def test_load(example_data, lazy): tas = Dataset( short_name="tas", mip="Amon", @@ -72,7 +74,11 @@ def test_load(example_data): tas.find_files() print(tas.files) - cube = tas.load() - + if lazy: + result = tas.load(compute=False) + assert isinstance(result, Delayed) + cube = result.compute() + else: + cube = tas.load() assert isinstance(cube, iris.cube.Cube) assert cube.cell_measures() 
diff --git a/tests/unit/test_cmor_api.py b/tests/unit/test_cmor_api.py index cce1fab9d8..66d69215ae 100644 --- a/tests/unit/test_cmor_api.py +++ b/tests/unit/test_cmor_api.py @@ -41,9 +41,7 @@ def test_cmor_check_metadata(mocker): check_level=sentinel.check_level, ) mock_get_cmor_checker.return_value.assert_called_once_with(sentinel.cube) - ( - mock_get_cmor_checker.return_value.return_value.check_metadata.assert_called_once_with() - ) + mock_get_cmor_checker.return_value.return_value.check_metadata.assert_called_once_with() assert cube == sentinel.checked_cube @@ -52,9 +50,6 @@ def test_cmor_check_data(mocker): mock_get_cmor_checker = mocker.patch.object( esmvalcore.cmor.check, "_get_cmor_checker", autospec=True ) - ( - mock_get_cmor_checker.return_value.return_value.check_data.return_value - ) = sentinel.checked_cube cube = cmor_check_data( sentinel.cube, @@ -73,10 +68,11 @@ def test_cmor_check_data(mocker): check_level=sentinel.check_level, ) mock_get_cmor_checker.return_value.assert_called_once_with(sentinel.cube) - ( - mock_get_cmor_checker.return_value.return_value.check_data.assert_called_once_with() + mock_get_cmor_checker.return_value.return_value.check_data.assert_called_once_with() + checked_cube = ( + mock_get_cmor_checker.return_value.return_value.check_data.return_value ) - assert cube == sentinel.checked_cube + assert cube == checked_cube def test_cmor_check(mocker):