Skip to content

Commit 0c830e9

Browse files
authored
fix: basemodel.predict_step (#672)
## Description We don't want to apply truncation in predict_step. ## What problem does this change solve? Bug breaking code when gather_out=True in predict_step. ## What issue or task does this change relate to? <!-- link to Issue Number --> ## Additional notes ## <!-- Include any additional information, caveats, or considerations that the reviewer should be aware of. --> ***As a contributor to the Anemoi framework, please ensure that your changes include unit tests, updates to any affected dependencies and documentation, and have been tested in a parallel setting (i.e., with multiple GPUs). As a reviewer, you are also responsible for verifying these aspects and requesting changes if they are not adequately addressed. For guidelines about those please refer to https://anemoi.readthedocs.io/en/latest/*** By opening this pull request, I affirm that all authors agree to the [Contributor License Agreement.](https://github.com/ecmwf/codex/blob/main/Legal/contributor_license_agreement.md)
1 parent dde3cb6 commit 0c830e9

File tree

1 file changed

+1
-1
lines changed
  • models/src/anemoi/models/models

1 file changed

+1
-1
lines changed

models/src/anemoi/models/models/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,6 @@ def predict_step(
228228

229229
# Gather output if needed
230230
if gather_out and model_comm_group is not None:
231-
y_hat = gather_tensor(y_hat, -2, self.truncation(y_hat, -2, grid_shard_shapes), model_comm_group)
231+
y_hat = gather_tensor(y_hat, -2, grid_shard_shapes, model_comm_group)
232232

233233
return y_hat

0 commit comments

Comments
 (0)