|
3 | 3 | from pathlib import Path |
4 | 4 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union |
5 | 5 |
|
6 | | -import joblib |
7 | 6 | from lightning_utilities import module_available |
8 | 7 |
|
9 | 8 | from litmodels.io.cloud import download_model_files, upload_model_files |
| 9 | +from litmodels.io.utils import dump_pickle, load_pickle |
10 | 10 |
|
11 | 11 | if module_available("torch"): |
12 | 12 | import torch |
@@ -56,7 +56,7 @@ def upload_model( |
56 | 56 | torch.save(model.state_dict(), path) |
57 | 57 | else: |
58 | 58 | path = os.path.join(staging_dir, f"{model.__class__.__name__}.pkl") |
59 | | - joblib.dump(model, path) |
| 59 | + dump_pickle(model=model, path=path) |
60 | 60 |
|
61 | 61 | return upload_model_files( |
62 | 62 | path=path, |
@@ -111,7 +111,7 @@ def load_model(name: str, download_dir: str = ".") -> Any: |
111 | 111 | raise NotImplementedError("Downloaded model with multiple files is not supported yet.") |
112 | 112 | model_path = Path(download_dir) / download_paths[0] |
113 | 113 | if model_path.suffix.lower() == ".pkl": |
114 | | - return joblib.load(model_path) |
| 114 | + return load_pickle(path=model_path) |
115 | 115 | if model_path.suffix.lower() == ".ts": |
116 | 116 | return torch.jit.load(model_path) |
117 | 117 | raise NotImplementedError(f"Loading model from {model_path.suffix} is not supported yet.") |
0 commit comments