Skip to content

Commit 0fadf18

Browse files
committed
PytorchLightningModelCheckpoint
1 parent 89bcefd commit 0fadf18

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

src/litmodels/integrations/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@
1010
__all__ += ["LightningModelCheckpoint"]
1111

1212
if _PYTORCHLIGHTNING_AVAILABLE:
13-
from litmodels.integrations.checkpoints import PTLightningModelCheckpoint
13+
from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint
1414

1515
__all__ += ["PTLightningModelCheckpoint"]

src/litmodels/integrations/checkpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def _save_checkpoint(self, trainer: "Trainer", filepath: str) -> None:
6767

6868
if _PYTORCHLIGHTNING_AVAILABLE:
6969

70-
class PTLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightningModelCheckpoint):
70+
class PytorchLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightningModelCheckpoint):
7171
"""PyTorch Lightning ModelCheckpoint with LitModel support.
7272
7373
Args:

tests/integrations/test_checkpoints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, importing,
2323
from lightning.pytorch.demos.boring_classes import BoringModel
2424
from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint
2525
elif importing == "pytorch_lightning":
26-
from litmodels.integrations.checkpoints import PTLightningModelCheckpoint as LitModelCheckpoint
26+
from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint as LitModelCheckpoint
2727
from pytorch_lightning import Trainer
2828
from pytorch_lightning.callbacks import ModelCheckpoint
2929
from pytorch_lightning.demos.boring_classes import BoringModel
@@ -64,7 +64,7 @@ def test_lightning_checkpointing_pickleable(mock_auth, importing):
6464
if importing == "lightning":
6565
from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint
6666
elif importing == "pytorch_lightning":
67-
from litmodels.integrations.checkpoints import PTLightningModelCheckpoint as LitModelCheckpoint
67+
from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint as LitModelCheckpoint
6868

6969
ckpt = LitModelCheckpoint(model_name="org-name/teamspace/model-name")
7070
pickle.dumps(ckpt)

0 commit comments

Comments
 (0)