Commit c1bbcec

JPXKQX and anaprietonem authored
fix(models): assert no dropout (#638)
## Description

This PR removes an extra `not` in the assertion that checks the use of dropout when the model is sharded across multiple GPUs.

***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines, please refer to https://anemoi.readthedocs.io/en/latest/***

By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md)

Co-authored-by: Ana Prieto Nemesio <[email protected]>
1 parent 0c830e9 · commit c1bbcec
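To make the effect of the change concrete, here is a minimal standalone sketch of the corrected guard (hypothetical helper name and values, not the Anemoi source): dropout is rejected only when the processor is actually sharded, i.e. when the model communication group spans more than one GPU.

```python
# Minimal sketch of the corrected logic (hypothetical names, not the Anemoi source).
# Dropout is only a problem when the processor is sharded across more than one GPU,
# so the check must pass either when there is a single GPU or when dropout is off.
def check_dropout_allowed(comm_group_size: int, has_dropout: bool) -> None:
    assert (
        comm_group_size == 1 or not has_dropout
    ), f"Dropout is not supported when model is sharded across {comm_group_size} GPUs"


check_dropout_allowed(comm_group_size=1, has_dropout=True)   # fine: no sharding
check_dropout_allowed(comm_group_size=4, has_dropout=False)  # fine: sharded, dropout disabled
# check_dropout_allowed(comm_group_size=4, has_dropout=True)  # would raise AssertionError
```

The removed assertion required `model_comm_group.size() > 1 and not self._has_dropout`, which fails for any single-GPU run regardless of dropout; the corrected condition only fires when sharding and dropout are combined.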

File tree

1 file changed: +11 −7 lines

models/src/anemoi/models/layers/processor.py

Lines changed: 11 additions & 7 deletions
@@ -55,6 +55,8 @@ def __init__(

         self.layer_factory = load_layer_kernels(layer_kernels)

+        self._has_dropout = kwargs.get("dropout_p", 0.0) > 0 if "dropout_p" in kwargs else False
+
         assert (
             num_layers % num_chunks == 0
         ), f"Number of processor layers ({num_layers}) has to be divisible by the number of processor chunks ({num_chunks})."
@@ -83,6 +85,12 @@ def run_layers(self, data: tuple, *args, **kwargs) -> Tensor:

     def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
         """Example forward pass."""
+
+        if (model_comm_group := kwargs.get("model_comm_group", None)) is not None:
+            assert (
+                model_comm_group.size() == 1 or not self._has_dropout
+            ), f"Dropout is not supported when model is sharded across {model_comm_group.size()} GPUs"
+
         x = self.run_layers((x,), *args, **kwargs)
         return x

@@ -108,6 +116,7 @@ def __init__(
             num_chunks=num_chunks,
             cpu_offload=cpu_offload,
             layer_kernels=layer_kernels,
+            dropout_p=dropout_p,
         )

         self.build_layers(
@@ -121,8 +130,6 @@ def __init__(

         self.offload_layers(cpu_offload)

-        self._has_dropout = dropout_p > 0 if dropout_p else False
-
     def forward(
         self,
         x: Tensor,
@@ -136,11 +143,7 @@ def forward(
         if model_comm_group:
             assert (
                 model_comm_group.size() == 1 or batch_size == 1
-            ), "Only batch size of 1 is supported when model is sharded accross GPUs"
-
-            assert (
-                model_comm_group.size() > 1 and not self._has_dropout
-            ), "Dropout is not supported when model is sharded across GPUS"
+            ), f"Only batch size of 1 is supported when model is sharded accross {model_comm_group.size()} GPUs"

         (x,) = self.run_layers((x,), shape_nodes, batch_size, model_comm_group, **kwargs)

@@ -210,6 +213,7 @@ def __init__(
             num_heads=num_heads,
             mlp_hidden_ratio=mlp_hidden_ratio,
             layer_kernels=layer_kernels,
+            dropout_p=dropout_p,
         )

         self.build_layers(
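
For reference, the `_has_dropout` flag now set in the base processor's `__init__` is derived from the `dropout_p` keyword forwarded by the subclasses. A small illustration of how that expression evaluates, with assumed values rather than repository code:

```python
# Illustration of the `_has_dropout` expression with a few assumed kwargs (not repository code).
def has_dropout(**kwargs) -> bool:
    return kwargs.get("dropout_p", 0.0) > 0 if "dropout_p" in kwargs else False


print(has_dropout())               # False: no dropout_p passed
print(has_dropout(dropout_p=0.0))  # False: dropout explicitly disabled
print(has_dropout(dropout_p=0.1))  # True: dropout enabled, which arms the sharding guard
```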
