|
1 | | -from typing import Any |
| 1 | +from typing import Any, Type, TypeVar, cast |
2 | 2 |
|
3 | 3 | from lightning_sdk.lightning_cloud.login import Auth |
4 | 4 |
|
|
7 | 7 |
|
8 | 8 | if _LIGHTNING_AVAILABLE: |
9 | 9 | from lightning.pytorch import Trainer |
10 | | - from lightning.pytorch.callbacks import ModelCheckpoint |
11 | | -elif _PYTORCHLIGHTNING_AVAILABLE: |
| 10 | + from lightning.pytorch.callbacks import ModelCheckpoint as LightningModelCheckpoint |
| 11 | +if _PYTORCHLIGHTNING_AVAILABLE: |
12 | 12 | from pytorch_lightning import Trainer |
13 | | - from pytorch_lightning.callbacks import ModelCheckpoint |
14 | | -else: |
15 | | - raise ModuleNotFoundError("No module named 'lightning' or 'pytorch_lightning'") |
| 13 | + from pytorch_lightning.callbacks import ModelCheckpoint as PytorchLightningModelCheckpoint |
16 | 14 |
|
17 | 15 |
|
18 | | -class LitModelCheckpoint(ModelCheckpoint): |
19 | | - """Lightning ModelCheckpoint with LitModel support. |
| 16 | +# Type variable for the ModelCheckpoint class |
| 17 | +ModelCheckpointType = TypeVar("ModelCheckpointType") |
| 18 | + |
| 19 | + |
| 20 | +def _model_checkpoint_template(checkpoint_cls: Type[ModelCheckpointType]) -> Type[ModelCheckpointType]: |
| 21 | + """Template function that returns a LitModelCheckpoint class for a specific ModelCheckpoint class. |
20 | 22 |
|
21 | 23 | Args: |
22 | | - model_name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname' |
23 | | - where entity is either your username or the name of an organization you are part of. |
24 | | - args: Additional arguments to pass to the parent class. |
25 | | - kwargs: Additional keyword arguments to pass to the parent class. |
| 24 | + checkpoint_cls: The ModelCheckpoint class to extend |
26 | 25 |
|
| 26 | + Returns: |
| 27 | + A LitModelCheckpoint class extending the given ModelCheckpoint class |
27 | 28 | """ |
28 | 29 |
|
29 | | - def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None: |
30 | | - """Initialize the LitModelCheckpoint.""" |
31 | | - super().__init__(*args, **kwargs) |
32 | | - self.model_name = model_name |
33 | | - |
34 | | - try: |
35 | | - # authenticate before anything else starts |
36 | | - auth = Auth() |
37 | | - auth.authenticate() |
38 | | - except Exception: |
39 | | - raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.") |
40 | | - |
41 | | - def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None: |
42 | | - super()._save_checkpoint(trainer, filepath) |
43 | | - # todo: uploading on background so training does nt stops |
44 | | - # todo: use filename as version but need to validate that such version does not exists yet |
45 | | - upload_model(name=self.model_name, model=filepath) |
| 30 | + class LitModelCheckpointTemplate(checkpoint_cls): # type: ignore |
| 31 | + """Lightning ModelCheckpoint with LitModel support. |
| 32 | +
|
| 33 | + Args: |
| 34 | + model_name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname' |
| 35 | + where entity is either your username or the name of an organization you are part of. |
| 36 | + args: Additional arguments to pass to the parent class. |
| 37 | + kwargs: Additional keyword arguments to pass to the parent class. |
| 38 | + """ |
| 39 | + |
| 40 | + def __init__(self, model_name: str, *args: Any, **kwargs: Any) -> None: |
| 41 | + """Initialize the LitModelCheckpoint.""" |
| 42 | + super().__init__(*args, **kwargs) |
| 43 | + self.model_name = model_name |
| 44 | + |
| 45 | + try: |
| 46 | + # authenticate before anything else starts |
| 47 | + auth = Auth() |
| 48 | + auth.authenticate() |
| 49 | + except Exception: |
| 50 | + raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.") |
| 51 | + |
| 52 | + def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None: |
| 53 | + super()._save_checkpoint(trainer, filepath) |
| 54 | + # upload model after checkpoint is saved |
| 55 | + upload_model(name=self.model_name, model=filepath) |
| 56 | + |
| 57 | + return cast(Type[ModelCheckpointType], LitModelCheckpointTemplate) |
| 58 | + |
| 59 | + |
| 60 | +# Create explicit classes with specific names if needed |
| 61 | +if _LIGHTNING_AVAILABLE: |
| 62 | + LightningModelCheckpoint = _model_checkpoint_template(LightningModelCheckpoint) |
| 63 | +if _PYTORCHLIGHTNING_AVAILABLE: |
| 64 | + PTLightningModelCheckpoint = _model_checkpoint_template(PytorchLightningModelCheckpoint) |
0 commit comments