Skip to content

Commit d97a4b6

Browse files
gmertesOpheliaMiralles
authored andcommitted
Update type hint
1 parent eeda349 commit d97a4b6

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

models/src/anemoi/models/models/encoder_processor_decoder.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,17 +160,17 @@ def _assert_valid_sharding(
160160

161161
def forward(
162162
self,
163-
x: Tensor,
163+
x: dict[str, Tensor],
164164
*,
165165
model_comm_group: Optional[ProcessGroup] = None,
166166
grid_shard_shapes: dict[str, list] | None = None,
167167
**kwargs,
168-
) -> Tensor:
168+
) -> dict[str, Tensor]:
169169
"""Forward pass of the model.
170170
171171
Parameters
172172
----------
173-
x : Tensor
173+
x : dict[str, Tensor]
174174
Input data
175175
model_comm_group : Optional[ProcessGroup], optional
176176
Model communication group, by default None
@@ -179,7 +179,7 @@ def forward(
179179
180180
Returns
181181
-------
182-
Tensor
182+
dict[str, Tensor]
183183
Output of the model, with the same shape as the input (sharded if input is sharded)
184184
"""
185185
# Multi-dataset case

0 commit comments

Comments
 (0)