diff --git a/tests/test_10_client.py b/tests/test_10_client.py index e42f38d..fb787c3 100644 --- a/tests/test_10_client.py +++ b/tests/test_10_client.py @@ -1,9 +1,12 @@ from pathlib import Path +import pytest + from xarray_esgf import Client -def test_download(tmp_path: Path, index_node: str) -> None: +@pytest.mark.parametrize("check_files", [True, False]) +def test_download(tmp_path: Path, index_node: str, check_files: bool) -> None: selection: dict[str, str | list[str]] = { "query": '"tas_Amon_EC-Earth3-CC_ssp245_r1i1p1f1_gr_201901-201912.nc"' } @@ -11,6 +14,7 @@ def test_download(tmp_path: Path, index_node: str) -> None: selection, esgpull_path=str(tmp_path / "esgpull"), index_node=index_node, + check_files=check_files, ) downloaded = client.download() diff --git a/xarray_esgf/client.py b/xarray_esgf/client.py index 10022ae..de39113 100644 --- a/xarray_esgf/client.py +++ b/xarray_esgf/client.py @@ -56,6 +56,7 @@ class Client: esgpull_path: str | Path | None = None index_node: str | None = None retries: int = 0 + check_files: bool = True verify_ssl: bool = False @cached_property @@ -91,9 +92,14 @@ def files(self) -> list[File]: @property def missing_files(self) -> list[File]: - return [ - file for file in self.files if self._client.fs.check(file) != FileCheck.Ok - ] + missing_files = [] + for file in tqdm.tqdm(self.files, desc="Looking for missing files:"): + file_path = Path(str(self._client.fs[file])) + if (self.check_files and self._client.fs.check(file) != FileCheck.Ok) or ( + not self.check_files and not file_path.exists() + ): + missing_files.append(file) + return missing_files def download(self) -> list[File]: files = [] diff --git a/xarray_esgf/engine.py b/xarray_esgf/engine.py index 67864bf..a00d585 100644 --- a/xarray_esgf/engine.py +++ b/xarray_esgf/engine.py @@ -16,17 +16,19 @@ def open_dataset( # type: ignore[override] drop_variables: str | Iterable[str] | None = None, esgpull_path: str | Path | None = None, index_node: str | None = None, + retries: int = 0, + check_files: bool = True, + verify_ssl: bool = False, concat_dims: DATASET_ID_KEYS | Iterable[DATASET_ID_KEYS] | None = None, download: bool = False, show_progress: bool = True, - retries: int = 0, - verify_ssl: bool = False, ) -> Dataset: client = Client( selection=filename_or_obj, esgpull_path=esgpull_path, index_node=index_node, retries=retries, + check_files=check_files, verify_ssl=verify_ssl, ) return client.open_dataset(