Skip to content

Commit b163ea9

Browse files
Tell FSDP2 about embedding engine forward functions (ecmwf#1133)
* Tell FSDP2 about embedding engine forward functions Note DO NOT add print functions in forward functions of the model, it will break with FSDP2 * Add comment
1 parent b4cc165 commit b163ea9

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

src/weathergen/train/trainer.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,16 @@ def init_model_and_shard(self, cf, devices):
239239
fully_shard(model)
240240
for tensor in itertools.chain(model.parameters(), model.buffers()):
241241
assert tensor.device == torch.device("meta")
242+
243+
# For reasons we do not yet fully understand, when using train continue in some
244+
# instances, FSDP2 does not register the forward_channels and forward_columns
245+
# functions in the embedding engine as forward functions. Thus, yielding a crash
246+
# because the input tensors are not converted to DTensors. This seems to primarily
247+
# occur during validation.
248+
for embed in model.embeds:
249+
torch.distributed.fsdp.register_fsdp_forward_method(embed, "forward_channels")
250+
torch.distributed.fsdp.register_fsdp_forward_method(embed, "forward_columns")
251+
242252
return model, model_params
243253

244254
def run(self, cf, devices, run_id_contd=None, epoch_contd=None):

0 commit comments

Comments
 (0)