
Commit f7d5ae4

fix: Remove import of anemoi training in compile (#705)
## Description

Removed a leftover import of anemoi.training in anemoi.models. This created a circular dependency which broke some downstream CI tests. I had to refactor the code because we now import torch.nn.Module rather than pl.LightningModule, so the syntax is different.

***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those, please refer to https://anemoi.readthedocs.io/en/latest/***

By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md).
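In essence, the refactor stops wrapping whole submodules with `torch.compile` and re-attaching them to their parents via `setattr`, and instead compiles each submodule's bound `forward` in place, which requires nothing beyond `torch.nn.Module`. A minimal sketch of the idea, not the actual anemoi implementation (the `names` argument and the helper itself are hypothetical, standing in for the modules selected by `compile_config`):

```python
import torch
from torch.nn import Module


def mark_for_compilation_sketch(model: Module, names: list[str]) -> Module:
    """Hypothetical helper illustrating the refactor."""
    for name in names:
        submodule = model.get_submodule(name)
        # Compile the bound forward in place. No parent lookup or
        # setattr replacement is needed, so this depends only on
        # torch.nn.Module, not on pl.LightningModule.
        submodule.forward = torch.compile(submodule.forward)
    return model
```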
1 parent 200101e · commit f7d5ae4

File tree

2 files changed: +5 −16 lines

models/src/anemoi/models/utils/compile.py

Lines changed: 3 additions & 14 deletions

```diff
@@ -8,16 +8,14 @@
 # nor does it submit to any jurisdiction.
 
 import logging
-from functools import reduce
 from importlib.util import find_spec
 
 import torch
 import torch_geometric
 from hydra.utils import get_class
 from numpy import unique
 from omegaconf import DictConfig
-
-from anemoi.training.train.tasks.base import BaseGraphModule
+from torch.nn import Module
 
 LOGGER = logging.getLogger(__name__)
 
@@ -58,7 +56,7 @@ def _meets_library_versions_for_compile() -> bool:
     return version_req and has_triton
 
 
-def mark_for_compilation(model: BaseGraphModule, compile_config: DictConfig | None) -> BaseGraphModule:
+def mark_for_compilation(model: Module, compile_config: DictConfig | None) -> Module:
     """Marks modules within 'model' for compilation, according to 'compile_config'.
 
     Modules are not compiled here. The compilation will occur
@@ -83,20 +81,11 @@ def mark_for_compilation(model: BaseGraphModule, compile_config: DictConfig | None) -> BaseGraphModule:
         options = entry.get("options", default_compile_options)
 
         LOGGER.debug("%s will be compiled with the following options: %s", str(module), str(options))
-        compiled_module = torch.compile(module, **options)  # Note: the module is not compiled yet
+        module.forward = torch.compile(module.forward, **options)  # Note: the function is not compiled yet
         # It is just marked for JIT-compilation later
         # It will be compiled before its first forward pass
         compiled_modules.append(entry.module)
 
-        # Update the model with the new 'compiled' module
-        # go from "anemoi.models.layers.conv.GraphTransformerConv"
-        # to obj(anemoi.models.layers.conv)
-        parts = name.split(".")
-        parent = reduce(getattr, parts[:-1], model)
-        # then set obj(anemoi.models.layers.conv).GrapTransformerConv = compiled_module
-        LOGGER.debug("Replacing %s with a compiled version", str(parts[-1]))
-        setattr(parent, parts[-1], compiled_module)
-
     LOGGER.info("The following modules will be compiled: %s", str(unique(compiled_modules)))
 
     return model
```
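Compiling the bound method in place works because `nn.Module.__call__` looks up `forward` on the instance first, so the assignment shadows the class attribute; and `torch.compile` returns immediately, deferring actual compilation to the first call. A minimal, self-contained sketch of that behaviour in plain PyTorch (no anemoi code assumed):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
# The instance attribute shadows nn.Linear.forward; torch.compile
# returns a wrapper immediately -- nothing is compiled yet.
layer.forward = torch.compile(layer.forward)

# JIT compilation is triggered by the first forward pass.
out = layer(torch.randn(2, 4))
```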

models/tests/utils/test_compile.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -109,7 +109,7 @@ def test_compile() -> None:
     result_compiled = ln_compiled.forward(x_in, cond)
 
     # check the function was compiled
-    assert hasattr(ln_compiled, "_compile_kwargs")
+    assert hasattr(ln_compiled.forward, "_torchdynamo_orig_callable")
 
     # check the result of the compiled function matches the uncompiled result
     assert torch.allclose(result, result_compiled)
@@ -145,7 +145,7 @@ def test_compile_layer_kernel() -> None:
     result_compiled = mhsa_compiled.forward(x, shapes, batch_size)
 
     # check the function was compiled
-    assert hasattr(mhsa_compiled.projection, "_compile_kwargs")
+    assert hasattr(mhsa_compiled.projection.forward, "_torchdynamo_orig_callable")
 
     # check the result of the compiled function matches the uncompiled result
     assert torch.allclose(result, result_compiled)
```
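The tests switch from checking the old `_compile_kwargs` marker on the module to checking `_torchdynamo_orig_callable` on the compiled `forward`, an attribute TorchDynamo attaches to the wrapper returned by `torch.compile`. A quick standalone illustration of the same check:

```python
import torch


def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2


compiled = torch.compile(double)
# The wrapper carries the marker even before anything is compiled;
# it references the original callable (here, `double`).
assert hasattr(compiled, "_torchdynamo_orig_callable")
```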
