Skip to content

Commit 0c78564

Browse files
committed
clean
1 parent abd784d commit 0c78564

File tree

2 files changed

+67
-8
lines changed

2 files changed

+67
-8
lines changed

src/anemoi/datasets/create/__init__.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,8 @@ def check_name(
256256
resolution: str,
257257
dates: list[datetime.datetime],
258258
frequency: datetime.timedelta,
259-
raise_exception: bool = False,
259+
raise_exception: bool = True,
260+
is_test: bool = False,
260261
) -> None:
261262
"""Check the name of the dataset.
262263
@@ -270,13 +271,15 @@ def check_name(
270271
The frequency of the dataset.
271272
raise_exception : bool, optional
272273
Whether to raise an exception if the name is invalid.
274+
is_test : bool, optional
275+
Whether this is a test.
273276
"""
274277
basename, _ = os.path.splitext(os.path.basename(self.path))
275278
try:
276279
DatasetName(basename, resolution, dates[0], dates[-1], frequency).raise_if_not_valid()
277280
except Exception as e:
278-
if raise_exception:
279-
raise
281+
if raise_exception and not is_test:
282+
raise e
280283
else:
281284
LOG.warning(f"Dataset name error: {e}")
282285

@@ -574,6 +577,7 @@ def __init__(
574577
use_threads: bool = False,
575578
statistics_temp_dir: str | None = None,
576579
progress: Any = None,
580+
test: bool = False,
577581
cache: str | None = None,
578582
**kwargs: Any,
579583
):
@@ -595,6 +599,8 @@ def __init__(
595599
The directory for temporary statistics.
596600
progress : Any, optional
597601
The progress indicator.
602+
test : bool, optional
603+
Whether this is a test.
598604
cache : Optional[str], optional
599605
The cache directory.
600606
"""
@@ -607,8 +613,9 @@ def __init__(
607613
self.use_threads = use_threads
608614
self.statistics_temp_dir = statistics_temp_dir
609615
self.progress = progress
616+
self.test = test
610617

611-
self.main_config = loader_config(config)
618+
self.main_config = loader_config(config, is_test=test)
612619

613620
# self.registry.delete() ??
614621
self.tmp_statistics.delete()
@@ -741,6 +748,7 @@ def _run(self) -> int:
741748

742749
self.dataset.check_name(
743750
raise_exception=self.check_name,
751+
is_test=self.test,
744752
resolution=resolution,
745753
dates=dates,
746754
frequency=frequency,

src/anemoi/datasets/create/config.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from anemoi.utils.config import load_any_dict_format
1919
from earthkit.data.core.order import normalize_order_by
2020

21+
from anemoi.datasets.dates.groups import Groups
22+
2123
LOG = logging.getLogger(__name__)
2224

2325

@@ -338,20 +340,72 @@ def _prepare_serialisation(o: Any) -> Any:
338340
return str(o)
339341

340342

341-
def loader_config(config: dict) -> LoadersConfig:
343+
def set_to_test_mode(cfg: dict) -> None:
344+
"""Modifies the configuration to run in test mode.
345+
346+
Parameters
347+
----------
348+
cfg : dict
349+
The configuration dictionary.
350+
"""
351+
NUMBER_OF_DATES = 4
352+
353+
LOG.warning(f"Running in test mode. Changing the list of dates to use only {NUMBER_OF_DATES}.")
354+
groups = Groups(**LoadersConfig(cfg).dates)
355+
356+
dates = groups.provider.values
357+
cfg["dates"] = dict(
358+
start=dates[0],
359+
end=dates[NUMBER_OF_DATES - 1],
360+
frequency=groups.provider.frequency,
361+
group_by=NUMBER_OF_DATES,
362+
)
363+
364+
num_ensembles = count_ensembles(cfg)
365+
366+
def set_element_to_test(obj):
367+
if isinstance(obj, (list, tuple)):
368+
for v in obj:
369+
set_element_to_test(v)
370+
return
371+
if isinstance(obj, (dict, DotDict)):
372+
if "grid" in obj and num_ensembles > 1:
373+
previous = obj["grid"]
374+
obj["grid"] = "20./20."
375+
LOG.warning(f"Running in test mode. Setting grid to {obj['grid']} instead of {previous}")
376+
if "number" in obj and num_ensembles > 1:
377+
if isinstance(obj["number"], (list, tuple)):
378+
previous = obj["number"]
379+
obj["number"] = previous[0:3]
380+
LOG.warning(f"Running in test mode. Setting number to {obj['number']} instead of {previous}")
381+
for k, v in obj.items():
382+
set_element_to_test(v)
383+
if "constants" in obj:
384+
constants = obj["constants"]
385+
if "param" in constants and isinstance(constants["param"], list):
386+
constants["param"] = ["cos_latitude"]
387+
388+
set_element_to_test(cfg)
389+
390+
391+
def loader_config(config: dict, is_test: bool = False) -> LoadersConfig:
342392
"""Loads and validates the configuration for dataset loaders.
343393
344394
Parameters
345395
----------
346396
config : dict
347397
The configuration dictionary.
398+
is_test : bool, optional
399+
Whether to run in test mode. Defaults to False.
348400
349401
Returns
350402
-------
351403
LoadersConfig
352404
The validated configuration object.
353405
"""
354406
config = Config(config)
407+
if is_test:
408+
set_to_test_mode(config)
355409
obj = LoadersConfig(config)
356410

357411
# yaml round trip to check that serialisation works as expected
@@ -372,9 +426,6 @@ def loader_config(config: dict) -> LoadersConfig:
372426
LOG.info(f"Setting env variable {k}={v}")
373427
os.environ[k] = str(v)
374428

375-
# Used by pytest only
376-
# copy.pop('checks', None)
377-
378429
return copy
379430

380431

0 commit comments

Comments
 (0)