Skip to content

Commit 921e108

Browse files
JPXKQXpre-commit-ci[bot]ssmmnn11anaprietonem
authored
feat(training)!: remove support for EDA (#651)
## Description <!-- What issue or task does this change relate to? --> This PR removes the EDA functionality, simplifying the logic of the datasets and EnsForecast and paving the way for easier integration with multiple datasets. ## What problem does this change solve? <!-- Describe if it's a bugfix, new feature, doc update, or breaking change --> ## What issue or task does this change relate to? <!-- link to Issue Number --> ## Additional notes ## <!-- Include any additional information, caveats, or considerations that the reviewer should be aware of. --> ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*** By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md) <!-- readthedocs-preview anemoi-training start --> ---- 📚 Documentation preview 📚: https://anemoi-training--651.org.readthedocs.build/en/651/ <!-- readthedocs-preview anemoi-training end --> <!-- readthedocs-preview anemoi-graphs start --> ---- 📚 Documentation preview 📚: https://anemoi-graphs--651.org.readthedocs.build/en/651/ <!-- readthedocs-preview anemoi-graphs end --> <!-- readthedocs-preview anemoi-models start --> ---- 📚 Documentation preview 📚: https://anemoi-models--651.org.readthedocs.build/en/651/ <!-- readthedocs-preview anemoi-models end --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Lang <[email protected]> Co-authored-by: Ana Prieto Nemesio <[email protected]>
1 parent c1bbcec commit 921e108

34 files changed

+185
-631
lines changed

models/src/anemoi/models/migrations/migrator.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -264,12 +264,12 @@ class MissingAttribute:
264264
"""Placeholder type when encountering ImportError or AttributeError in Unpickler.find_class"""
265265

266266

267-
def _get_unpickler(replace_attrs: list[str] | bool = False):
267+
def _get_unpickler(replace_attrs: dict[str, list[str]] | bool = False):
268268
"""Get the Unpickler
269269
270270
Parameters
271271
----------
272-
replace_attrs : list[str] | bool, default False
272+
replace_attrs : dict[str,list[str]] | bool, default False
273273
Replace the provided attrs by a ``MissingAttribute`` object. If False, will not
274274
try to replace attributes. If True, will replace every missing attribute. You can use
275275
* as a wildcard to be replaced by any attribute in a module.
@@ -289,11 +289,26 @@ def find_class(self, module_name: str, global_name: str, /) -> Any:
289289
try:
290290
return super().find_class(module_name, global_name)
291291
except (ImportError, AttributeError) as e:
292+
293+
deleted_modules: list[str] = []
294+
deleted_attributes: list[str] = []
295+
296+
# --- Normalize replace_attrs ---
297+
if isinstance(replace_attrs, dict):
298+
deleted_modules = replace_attrs.get("deleted_modules", [])
299+
deleted_attributes = replace_attrs.get("deleted_attributes", [])
300+
292301
attr_name = f"{module_name}.{global_name}"
293302
wild_name = f"{module_name}.*"
303+
294304
if replace_attrs is False:
295305
raise e
296-
if replace_attrs is True or attr_name in replace_attrs or wild_name in replace_attrs:
306+
if (
307+
replace_attrs is True
308+
or attr_name in deleted_attributes
309+
or module_name in deleted_modules
310+
or wild_name in replace_attrs
311+
):
297312
LOGGER.debug("Missing attribute %s.%s is checkpoint. Ignoring.", module_name, global_name)
298313
return MissingAttribute
299314
raise e
@@ -308,7 +323,7 @@ class UnpicklerWrapper:
308323
return UnpicklerWrapper
309324

310325

311-
def _load_ckpt(path: str | PathLike, replace_attrs: list[str] | bool = False) -> CkptType:
326+
def _load_ckpt(path: str | PathLike, replace_attrs: dict[str, list[str]] | bool = False) -> CkptType:
312327
"""Loads a checkpoint
313328
314329
Parameters
@@ -570,6 +585,11 @@ def _resolve_context(self, context: MigrationContext) -> None:
570585
context : MigrationContext
571586
The context object
572587
"""
588+
for module_path in getattr(context, "deleted_modules", []):
589+
if module_path in sys.modules:
590+
LOGGER.debug("Delete module %s.", module_path)
591+
del sys.modules[module_path]
592+
573593
for module_path_end, module_path_start in context.module_paths.items():
574594
LOGGER.debug("Move module %s to %s.", module_path_start, module_path_end)
575595
sys.modules[module_path_start] = sys.modules[module_path_end]
@@ -614,13 +634,14 @@ def sync(self, path: str | PathLike) -> tuple[CkptType, CkptType, list[BaseOp]]:
614634
compatible_migrations = self._grouped_migrations[-1]
615635
self._check_executed_migrations(ckpt, compatible_migrations)
616636
setups, ops = self._resolve_operations(ckpt, compatible_migrations)
617-
replace_attrs: list[str] = []
637+
replace_attrs: dict[str, list[str]] = {}
618638
if len(setups):
619639
context = MigrationContext()
620640
for setup in setups:
621641
setup(context)
622642
self._resolve_context(context)
623-
replace_attrs = context.deleted_attributes
643+
replace_attrs["deleted_modules"] = context.deleted_modules
644+
replace_attrs["deleted_attributes"] = context.deleted_attributes
624645
# Force reloading checkpoint without obfuscating import issues.
625646
ckpt = _load_ckpt(path, replace_attrs)
626647
ckpt["hyper_parameters"]["metadata"].setdefault("migrations", {}).setdefault("history", [])
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# (C) Copyright 2025 Anemoi contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
from anemoi.models.migrations import CkptType
11+
from anemoi.models.migrations import MigrationContext
12+
from anemoi.models.migrations import MigrationMetadata
13+
14+
# DO NOT CHANGE -->
15+
metadata = MigrationMetadata(
16+
versions={
17+
"migration": "1.0.0",
18+
"anemoi-models": "%NEXT_ANEMOI_MODELS_VERSION%",
19+
},
20+
)
21+
# <-- END DO NOT CHANGE
22+
23+
24+
def migrate_setup(context: MigrationContext) -> None:
25+
"""Migrate setup callback to be run before loading the checkpoint.
26+
27+
Parameters
28+
----------
29+
context : MigrationContext
30+
A MigrationContext instance
31+
"""
32+
context.delete_module("anemoi.training.schemas.datamodule")
33+
34+
35+
def migrate(ckpt: CkptType) -> CkptType:
36+
"""Migrate the checkpoint.
37+
38+
Parameters
39+
----------
40+
ckpt : CkptType
41+
The checkpoint dict.
42+
43+
Returns
44+
-------
45+
CkptType
46+
The migrated checkpoint dict.
47+
"""
48+
return ckpt
49+
50+
51+
def rollback(ckpt: CkptType) -> CkptType:
52+
"""Rollback the checkpoint.
53+
54+
Parameters
55+
----------
56+
ckpt : CkptType
57+
The checkpoint dict.
58+
59+
Returns
60+
-------
61+
CkptType
62+
The rolled-back checkpoint dict.
63+
"""
64+
return ckpt

models/src/anemoi/models/migrations/setup_context.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def __init__(self) -> None:
2929
self.attribute_paths: dict[str, str] = {}
3030
self.module_paths: dict[str, str] = {}
3131
self.deleted_attributes: list[str] = []
32+
self.deleted_modules: list[str] = []
3233

3334
def delete_attribute(self, path: str) -> None:
3435
"""Indicate that an attribute has been deleted. Any class referencing this module will
@@ -41,6 +42,10 @@ def delete_attribute(self, path: str) -> None:
4142
"""
4243
self.deleted_attributes.append(path)
4344

45+
def delete_module(self, path: str) -> None:
46+
"""Mark a module for deletion."""
47+
self.deleted_modules.append(path)
48+
4449
def move_attribute(self, path_start: str, path_end: str) -> None:
4550
"""Move and rename an attribute between modules.
4651
@@ -76,6 +81,7 @@ class SerializedMigrationContext(TypedDict):
7681
attribute_paths: dict[str, str]
7782
module_paths: dict[str, str]
7883
deleted_attributes: list[str]
84+
deleted_modules: list[str]
7985

8086

8187
def serialize_setup_callback(setup: Callable[[MigrationContext], None]) -> SerializedMigrationContext:
@@ -98,6 +104,7 @@ def serialize_setup_callback(setup: Callable[[MigrationContext], None]) -> Seria
98104
"attribute_paths": ctx.attribute_paths,
99105
"module_paths": ctx.module_paths,
100106
"deleted_attributes": ctx.deleted_attributes,
107+
"deleted_modules": ctx.deleted_modules,
101108
}
102109

103110

@@ -112,6 +119,8 @@ def __call__(self, context: MigrationContext) -> None:
112119
context.delete_attribute(deleted_attribute)
113120
for path_end, path_start in self._ctx["attribute_paths"].items():
114121
context.move_attribute(path_start, path_end)
122+
for deleted_module in self._ctx["deleted_modules"].items():
123+
context.delete_module(deleted_module)
115124
for path_end, path_start in self._ctx["module_paths"].items():
116125
context.move_module(path_start, path_end)
117126

models/src/anemoi/models/models/ens_encoder_processor_decoder.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,38 +15,18 @@
1515
import torch
1616
from hydra.utils import instantiate
1717
from torch.distributed.distributed_c10d import ProcessGroup
18-
from torch_geometric.data import HeteroData
1918

2019
from anemoi.models.distributed.graph import shard_tensor
2120
from anemoi.models.distributed.shapes import get_or_apply_shard_shapes
2221
from anemoi.models.distributed.shapes import get_shard_shapes
2322
from anemoi.models.models import AnemoiModelEncProcDec
24-
from anemoi.utils.config import DotDict
2523

2624
LOGGER = logging.getLogger(__name__)
2725

2826

2927
class AnemoiEnsModelEncProcDec(AnemoiModelEncProcDec):
3028
"""Message passing graph neural network with ensemble functionality."""
3129

32-
def __init__(
33-
self,
34-
*,
35-
model_config: DotDict,
36-
data_indices: dict,
37-
statistics: dict,
38-
graph_data: HeteroData,
39-
truncation_data: dict,
40-
) -> None:
41-
42-
super().__init__(
43-
model_config=model_config,
44-
data_indices=data_indices,
45-
statistics=statistics,
46-
graph_data=graph_data,
47-
truncation_data=truncation_data,
48-
)
49-
5030
def _calculate_input_dim(self):
5131
base_input_dim = super()._calculate_input_dim()
5232
return base_input_dim + self.num_input_channels_prognostic + 1

training/docs/user-guide/configuring.rst

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ settings at the top as follows:
2222
- data: zarr
2323
- dataloader: native_grid
2424
- diagnostics: evaluation
25-
- datamodule: single
2625
- hardware: example
2726
- graph: multi_scale
2827
- model: gnn
@@ -102,7 +101,6 @@ match the dataset you provide.
102101
- data: zarr
103102
- dataloader: native_grid
104103
- diagnostics: evaluation
105-
- datamodule: single
106104
- hardware: example
107105
- graph: multi_scale
108106
- model: transformer # Change from default group

training/docs/user-guide/diffusion-set-up.rst

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,6 @@ A minimal config file for standard diffusion training:
226226
- data: zarr
227227
- dataloader: native_grid
228228
- diagnostics: evaluation
229-
- datamodule: single
230229
- hardware: example
231230
- graph: multi_scale
232231
- model: graphtransformer_diffusion # Use diffusion model
@@ -247,7 +246,6 @@ For tendency-based diffusion, change the model config and model task:
247246
- data: zarr
248247
- dataloader: native_grid
249248
- diagnostics: evaluation
250-
- datamodule: single
251249
- hardware: example
252250
- graph: multi_scale
253251
- model: graphtransformer_diffusiontend # Use tendency diffusion model

training/docs/user-guide/yaml/example_crps_config.yaml

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,6 @@ hardware:
2020
num_nodes: 1
2121
num_gpus_per_model: 1
2222

23-
# Changes in datamodule
24-
datamodule:
25-
_target_: anemoi.training.data.datamodule.AnemoiEnsDatasetsDataModule
26-
2723
data:
2824
resolution: o96
2925

training/src/anemoi/training/config/config.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ defaults:
22
- data: zarr
33
- dataloader: native_grid
44
- diagnostics: evaluation
5-
- datamodule: single
65
- hardware: example
76
- graph: multi_scale
87
- model: gnn

training/src/anemoi/training/config/datamodule/ens.yaml

Lines changed: 0 additions & 1 deletion
This file was deleted.

training/src/anemoi/training/config/datamodule/single.yaml

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments

Comments
 (0)