Skip to content

Commit 104212c

Browse files
VeraChristina and pre-commit-ci[bot]
authored and committed
feat: metadata for multi datasets (ecmwf#762)
Update the metadata we store in checkpoints in order to provide more information to inference. The checkpoint mentioned [here](ecmwf#594 (comment)) was created based on this branch. ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*** By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d97a4b6 commit 104212c

File tree

10 files changed

+248
-371
lines changed

10 files changed

+248
-371
lines changed

models/src/anemoi/models/interface/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def __init__(
7777
self.supporting_arrays = supporting_arrays if supporting_arrays is not None else {}
7878
self.data_indices = data_indices
7979
self._build_model()
80+
self._update_metadata()
8081

8182
def _build_processors_for_dataset(
8283
self, dataset_name: str, statistics: dict, data_indices: dict, statistics_tendencies: dict = None
@@ -207,3 +208,6 @@ def predict_step(
207208

208209
# Delegate to the model's predict_step implementation with processors
209210
return self.model.predict_step(**predict_kwargs, **kwargs)
211+
212+
def _update_metadata(self) -> None:
213+
self.model.fill_metadata(self.metadata)

models/src/anemoi/models/models/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,3 +312,8 @@ def predict_step(
312312
y_hat[dataset_name] = gather_tensor(y_hat[dataset_name], -2, y_hat_shard_shapes, model_comm_group)
313313

314314
return y_hat
315+
316+
@abstractmethod
def fill_metadata(self, md_dict) -> None:
    """Fill model-specific metadata; to be implemented in subclasses.

    Parameters
    ----------
    md_dict : dict
        Metadata dictionary that the concrete model is expected to
        extend with its own entries.
    """

models/src/anemoi/models/models/diffusion_encoder_processor_decoder.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,24 @@ def sample(
589589
grid_shard_shapes=grid_shard_shapes,
590590
)
591591

592+
def fill_metadata(self, md_dict) -> None:
    """Add per-dataset input shapes and timestep indices to ``md_dict``.

    For every dataset in ``self.input_dim`` the following entries of
    ``md_dict["metadata_inference"][dataset]`` are written in place:

    * ``["shapes"]``: a dict with ``variables`` (``self.input_dim[dataset]``),
      ``input_timesteps`` (``self.multi_step``), ``ensemble`` (always 1)
      and ``grid`` (always ``None``).
    * ``["timesteps"]["input_relative_date_indices"]``: every entry of
      ``relative_date_indices_training`` except the last.
    * ``["timesteps"]["output_relative_date_indices"]``: the last entry of
      ``relative_date_indices_training`` (a single element, not a slice).

    Parameters
    ----------
    md_dict : dict
        Metadata dictionary, modified in place. For each dataset,
        ``md_dict["metadata_inference"][dataset]["timesteps"]
        ["relative_date_indices_training"]`` must already be present.
    """
    for dataset, num_variables in self.input_dim.items():
        dataset_md = md_dict["metadata_inference"][dataset]
        dataset_md["shapes"] = {
            "variables": num_variables,
            "input_timesteps": self.multi_step,
            "ensemble": 1,
            "grid": None,  # grid size is dynamic
        }

        timesteps = dataset_md["timesteps"]
        rel_date_indices = timesteps["relative_date_indices_training"]
        timesteps["input_relative_date_indices"] = rel_date_indices[:-1]
        timesteps["output_relative_date_indices"] = rel_date_indices[-1]
609+
592610

593611
class AnemoiDiffusionTendModelEncProcDec(AnemoiDiffusionModelEncProcDec):
594612
"""Diffusion model for tendency prediction."""

models/src/anemoi/models/models/encoder_processor_decoder.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,3 +273,21 @@ def forward(
273273
)
274274

275275
return x_out_dict
276+
277+
def fill_metadata(self, md_dict) -> None:
    """Record input shapes and input/output timestep indices per dataset.

    Iterates over the datasets in ``self.input_dim`` and, for each one,
    updates ``md_dict["metadata_inference"][dataset]`` in place:

    * ``"shapes"`` is set to a dict holding ``variables``
      (``self.input_dim[dataset]``), ``input_timesteps``
      (``self.multi_step``), ``ensemble`` (fixed to 1) and ``grid``
      (``None``).
    * ``"timesteps"`` gains ``input_relative_date_indices`` (all but the
      last element of ``relative_date_indices_training``) and
      ``output_relative_date_indices`` (its last element, stored as a
      single value).

    Parameters
    ----------
    md_dict : dict
        Metadata dictionary, modified in place. Each dataset entry must
        already contain ``["timesteps"]["relative_date_indices_training"]``.
    """
    inference_md = md_dict["metadata_inference"]
    for dataset, num_variables in self.input_dim.items():
        entry = inference_md[dataset]
        entry["shapes"] = {
            "variables": num_variables,
            "input_timesteps": self.multi_step,
            "ensemble": 1,
            "grid": None,  # grid size is dynamic
        }

        timesteps = entry["timesteps"]
        training_indices = timesteps["relative_date_indices_training"]
        # The final training index is the target step; everything before it is input.
        timesteps["input_relative_date_indices"] = training_indices[:-1]
        timesteps["output_relative_date_indices"] = training_indices[-1]

0 commit comments

Comments
 (0)