File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed
models/src/anemoi/models/models Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -160,17 +160,17 @@ def _assert_valid_sharding(
160160
161161 def forward (
162162 self ,
163- x : Tensor ,
163+ x : dict [ str , Tensor ] ,
164164 * ,
165165 model_comm_group : Optional [ProcessGroup ] = None ,
166166 grid_shard_shapes : dict [str , list ] | None = None ,
167167 ** kwargs ,
168- ) -> Tensor :
168+ ) -> dict [ str , Tensor ] :
169169 """Forward pass of the model.
170170
171171 Parameters
172172 ----------
173- x : Tensor
173+ x : dict[str, Tensor]
174174 Input data
175175 model_comm_group : Optional[ProcessGroup], optional
176176 Model communication group, by default None
@@ -179,7 +179,7 @@ def forward(
179179
180180 Returns
181181 -------
182- Tensor
182+ dict[str, Tensor]
183183 Output of the model, with the same shape as the input (sharded if input is sharded)
184184 """
185185 # Multi-dataset case
You can’t perform that action at this time.
0 commit comments