|
1 | | -from typing import Any, Type, TypeVar, cast |
| 1 | +from typing import TYPE_CHECKING, Any |
2 | 2 |
|
3 | 3 | from lightning_sdk.lightning_cloud.login import Auth |
4 | 4 | from lightning_utilities.core.rank_zero import rank_zero_only |
|
7 | 7 | from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE |
8 | 8 |
|
9 | 9 | if _LIGHTNING_AVAILABLE: |
10 | | - from lightning.pytorch import Trainer |
11 | | - from lightning.pytorch.callbacks import ModelCheckpoint as LightningModelCheckpoint |
| 10 | + from lightning.pytorch.callbacks import ModelCheckpoint as _LightningModelCheckpoint |
| 11 | + |
| 12 | + if TYPE_CHECKING: |
| 13 | + from lightning.pytorch import Trainer |
| 14 | + |
| 15 | + |
12 | 16 | if _PYTORCHLIGHTNING_AVAILABLE: |
13 | | - from pytorch_lightning import Trainer |
14 | | - from pytorch_lightning.callbacks import ModelCheckpoint as PytorchLightningModelCheckpoint |
| 17 | + from pytorch_lightning.callbacks import ModelCheckpoint as _PytorchLightningModelCheckpoint |
| 18 | + |
| 19 | + if TYPE_CHECKING: |
| 20 | + from pytorch_lightning import Trainer |
| 21 | + |
15 | 22 |
|
| 23 | +# Base class to be inherited |
| 24 | +class LitModelCheckpointMixin: |
| 25 | + """Mixin class for LitModel checkpoint functionality.""" |
16 | 26 |
|
17 | | -# Type variable for the ModelCheckpoint class |
18 | | -ModelCheckpointType = TypeVar("ModelCheckpointType") |
| 27 | + def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None: |
| 28 | + """Initialize with model name.""" |
| 29 | + self.model_name = model_name |
19 | 30 |
|
| 31 | + try: # authenticate before anything else starts |
| 32 | + auth = Auth() |
| 33 | + auth.authenticate() |
| 34 | + except Exception: |
| 35 | + raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.") |
20 | 36 |
|
21 | | -def _model_checkpoint_template(checkpoint_cls: Type[ModelCheckpointType]) -> Type[ModelCheckpointType]: |
22 | | - """Template function that returns a LitModelCheckpoint class for a specific ModelCheckpoint class. |
| 37 | + @rank_zero_only |
| 38 | + def _upload_model(self, filepath: str) -> None: |
| 39 | + # todo: uploading on background so training does nt stops |
| 40 | + # todo: use filename as version but need to validate that such version does not exists yet |
| 41 | + upload_model(name=self.model_name, model=filepath) |
23 | 42 |
|
24 | | - Args: |
25 | | - checkpoint_cls: The ModelCheckpoint class to extend |
26 | 43 |
|
27 | | - Returns: |
28 | | - A LitModelCheckpoint class extending the given ModelCheckpoint class |
29 | | - """ |
| 44 | +# Create specific implementations |
| 45 | +if _LIGHTNING_AVAILABLE: |
30 | 46 |
|
31 | | - class LitModelCheckpointTemplate(checkpoint_cls): # type: ignore |
| 47 | + class LightningModelCheckpoint(LitModelCheckpointMixin, _LightningModelCheckpoint): |
32 | 48 | """Lightning ModelCheckpoint with LitModel support. |
33 | 49 |
|
34 | 50 | Args: |
35 | | - model_name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname' |
36 | | - where entity is either your username or the name of an organization you are part of. |
| 51 | + model_name: Name of the model to upload in format 'organization/teamspace/modelname' |
37 | 52 | args: Additional arguments to pass to the parent class. |
38 | 53 | kwargs: Additional keyword arguments to pass to the parent class. |
39 | 54 | """ |
40 | 55 |
|
41 | 56 | def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None: |
42 | | - """Initialize the LitModelCheckpoint.""" |
43 | | - super().__init__(*args, **kwargs) |
44 | | - self.model_name = model_name |
45 | | - |
46 | | - try: # authenticate before anything else starts |
47 | | - auth = Auth() |
48 | | - auth.authenticate() |
49 | | - except Exception: |
50 | | - raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.") |
51 | | - |
52 | | - @rank_zero_only |
53 | | - def _upload_model(self, filepath: str) -> None: |
54 | | - # todo: uploading on background so training does nt stops |
55 | | - # todo: use filename as version but need to validate that such version does not exists yet |
56 | | - upload_model(name=self.model_name, model=filepath) |
57 | | - |
58 | | - def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None: |
| 57 | + """Initialize the checkpoint with model name and other parameters.""" |
| 58 | + _LightningModelCheckpoint.__init__(self, *args, **kwargs) |
| 59 | + LitModelCheckpointMixin.__init__(self, model_name) |
| 60 | + |
| 61 | + def _save_checkpoint(self, trainer: "Trainer", filepath: str) -> None: |
59 | 62 | super()._save_checkpoint(trainer, filepath) |
60 | 63 | self._upload_model(filepath) |
61 | 64 |
|
62 | | - return cast(Type[ModelCheckpointType], LitModelCheckpointTemplate) |
63 | | - |
64 | 65 |
|
65 | | -# Create explicit classes with specific names |
66 | | -if _LIGHTNING_AVAILABLE: |
67 | | - LightningModelCheckpoint = _model_checkpoint_template(LightningModelCheckpoint) |
68 | 66 | if _PYTORCHLIGHTNING_AVAILABLE: |
69 | | - PTLightningModelCheckpoint = _model_checkpoint_template(PytorchLightningModelCheckpoint) |
| 67 | + |
| 68 | + class PTLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightningModelCheckpoint): |
| 69 | + """PyTorch Lightning ModelCheckpoint with LitModel support. |
| 70 | +
|
| 71 | + Args: |
| 72 | + model_name: Name of the model to upload in format 'organization/teamspace/modelname' |
| 73 | + args: Additional arguments to pass to the parent class. |
| 74 | + kwargs: Additional keyword arguments to pass to the parent class. |
| 75 | + """ |
| 76 | + |
| 77 | + def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None: |
| 78 | + """Initialize the checkpoint with model name and other parameters.""" |
| 79 | + _PytorchLightningModelCheckpoint.__init__(self, *args, **kwargs) |
| 80 | + LitModelCheckpointMixin.__init__(self, model_name) |
| 81 | + |
| 82 | + def _save_checkpoint(self, trainer: "Trainer", filepath: str) -> None: |
| 83 | + super()._save_checkpoint(trainer, filepath) |
| 84 | + self._upload_model(filepath) |
0 commit comments