Upload checkpoints only from the global-zero process.

Both `_save_checkpoint` overrides in checkpoints.py used to call
`self._upload_model(filepath)` unconditionally right after
`super()._save_checkpoint(trainer, filepath)`. The upload is now guarded by
`trainer.is_global_zero`; the local save via `super()` still runs on every
process. (Presumably this avoids each rank of a distributed run uploading the
same file -- confirm against the trainer strategy in use.)

NOTE(review): this patch was recovered from a copy whose newlines and leading
whitespace had been collapsed. Indentation of the context lines below was
reconstructed (methods at 8 spaces, nested under the availability guards);
verify against the target file before applying.

diff --git a/src/litmodels/integrations/checkpoints.py b/src/litmodels/integrations/checkpoints.py
index a1cbc90..9b5aebb 100644
--- a/src/litmodels/integrations/checkpoints.py
+++ b/src/litmodels/integrations/checkpoints.py
@@ -60,7 +60,9 @@ def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
 
         def _save_checkpoint(self, trainer: "Trainer", filepath: str) -> None:
             super()._save_checkpoint(trainer, filepath)
-            self._upload_model(filepath)
+            if trainer.is_global_zero:
+                # Only upload from the main process
+                self._upload_model(filepath)
 
 
 if _PYTORCHLIGHTNING_AVAILABLE:
@@ -81,4 +83,6 @@ def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None:
 
         def _save_checkpoint(self, trainer: "Trainer", filepath: str) -> None:
             super()._save_checkpoint(trainer, filepath)
-            self._upload_model(filepath)
+            if trainer.is_global_zero:
+                # Only upload from the main process
+                self._upload_model(filepath)