Skip to content

Commit be9ff8b

Browse files
japolsHCookie
andauthored
fix: predict_step shard shapes (#692)
## Description Shard shapes fix when gathering in predict_step. ## What problem does this change solve? <!-- Describe if it's a bugfix, new feature, doc update, or breaking change --> ## What issue or task does this change relate to? <!-- link to Issue Number --> ## Additional notes ## <!-- Include any additional information, caveats, or considerations that the reviewer should be aware of. --> ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*** By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md) Co-authored-by: Harrison Cook <[email protected]>
1 parent aeaf00b commit be9ff8b

File tree

1 file changed

+3
-1
lines changed
  • models/src/anemoi/models/models

1 file changed

+3
-1
lines changed

models/src/anemoi/models/models/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from anemoi.models.distributed.graph import gather_tensor
2323
from anemoi.models.distributed.graph import shard_tensor
24+
from anemoi.models.distributed.shapes import apply_shard_shapes
2425
from anemoi.models.distributed.shapes import get_shard_shapes
2526
from anemoi.models.layers.bounding import build_boundings
2627
from anemoi.models.layers.graph import NamedNodesAttributes
@@ -226,6 +227,7 @@ def predict_step(
226227

227228
# Gather output if needed
228229
if gather_out and model_comm_group is not None:
229-
y_hat = gather_tensor(y_hat, -2, grid_shard_shapes, model_comm_group)
230+
y_hat_shard_shapes = apply_shard_shapes(y_hat, -2, grid_shard_shapes)
231+
y_hat = gather_tensor(y_hat, -2, y_hat_shard_shapes, model_comm_group)
230232

231233
return y_hat

0 commit comments

Comments
 (0)