|
3 | 3 | from pathlib import Path |
4 | 4 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union |
5 | 5 |
|
6 | | -from lightning_utilities import module_available |
7 | | - |
8 | 6 | from litmodels.io.cloud import download_model_files, upload_model_files |
9 | | -from litmodels.io.utils import dump_pickle, load_pickle |
| 7 | +from litmodels.io.utils import _KERAS_AVAILABLE, _PYTORCH_AVAILABLE, dump_pickle, load_pickle |
10 | 8 |
|
11 | | -if module_available("torch"): |
| 9 | +if _PYTORCH_AVAILABLE: |
12 | 10 | import torch |
13 | | -else: |
14 | | - torch = None |
| 11 | + |
| 12 | +if _KERAS_AVAILABLE: |
| 13 | + from tensorflow import keras |
15 | 14 |
|
16 | 15 | if TYPE_CHECKING: |
17 | 16 | from lightning_sdk.models import UploadedModelInfo |
@@ -48,12 +47,15 @@ def upload_model( |
48 | 47 | # if LightningModule and isinstance(model, LightningModule): |
49 | 48 | # path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt") |
50 | 49 | # model.save_checkpoint(path) |
51 | | - elif torch and isinstance(model, torch.jit.ScriptModule): |
| 50 | + elif _PYTORCH_AVAILABLE and isinstance(model, torch.jit.ScriptModule): |
52 | 51 | path = os.path.join(staging_dir, f"{model.__class__.__name__}.ts") |
53 | 52 | model.save(path) |
54 | | - elif torch and isinstance(model, torch.nn.Module): |
| 53 | + elif _PYTORCH_AVAILABLE and isinstance(model, torch.nn.Module): |
55 | 54 | path = os.path.join(staging_dir, f"{model.__class__.__name__}.pth") |
56 | 55 | torch.save(model.state_dict(), path) |
| 56 | + elif _KERAS_AVAILABLE and isinstance(model, keras.models.Model): |
| 57 | + path = os.path.join(staging_dir, f"{model.__class__.__name__}.keras") |
| 58 | + model.save(path) |
57 | 59 | else: |
58 | 60 | path = os.path.join(staging_dir, f"{model.__class__.__name__}.pkl") |
59 | 61 | dump_pickle(model=model, path=path) |
@@ -110,8 +112,10 @@ def load_model(name: str, download_dir: str = ".") -> Any: |
110 | 112 | if len(download_paths) > 1: |
111 | 113 | raise NotImplementedError("Downloaded model with multiple files is not supported yet.") |
112 | 114 | model_path = Path(download_dir) / download_paths[0] |
113 | | - if model_path.suffix.lower() == ".pkl": |
114 | | - return load_pickle(path=model_path) |
115 | 115 | if model_path.suffix.lower() == ".ts": |
116 | 116 | return torch.jit.load(model_path) |
| 117 | + if model_path.suffix.lower() == ".keras": |
| 118 | + return keras.models.load_model(model_path) |
| 119 | + if model_path.suffix.lower() == ".pkl": |
| 120 | + return load_pickle(path=model_path) |
117 | 121 | raise NotImplementedError(f"Loading model from {model_path.suffix} is not supported yet.") |
0 commit comments