Skip to content

Commit 3c663b8

Browse files
update pickling with joblib (#87)
* update pickling with `joblib`
* usage

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 93b8ab6 commit 3c663b8

File tree

6 files changed

+50
-15
lines changed

6 files changed

+50
-15
lines changed

.github/workflows/ci-testing.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ jobs:
2121
requires: ["latest"]
2222
dependency: ["lightning"]
2323
include:
24-
- { os: "ubuntu-22.04", python-version: "3.9", requires: "oldest", dependency: "lightning" }
25-
- { os: "ubuntu-24.04", python-version: "3.10", requires: "latest", dependency: "pytorch_lightning" }
26-
- { os: "ubuntu-24.04", python-version: "3.12", requires: "latest", dependency: "pytorch_lightning" }
27-
- { os: "macOS-13", python-version: "3.12", requires: "latest", dependency: "pytorch_lightning" }
24+
- { requires: "oldest", dependency: "lightning", os: "ubuntu-22.04", python-version: "3.9" }
25+
- { requires: "latest", dependency: "pytorch_lightning", os: "ubuntu-24.04", python-version: "3.12" }
26+
- { requires: "latest", dependency: "pytorch_lightning", os: "windows-2022", python-version: "3.12" }
27+
- { requires: "latest", dependency: "pytorch_lightning", os: "macOS-13", python-version: "3.12" }
2828

2929
# Timeout: https://stackoverflow.com/a/59076067/4521646
3030
timeout-minutes: 35

_requirements/extra.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
lightning >= 2.0.0
22
numpy <2.0.0 ; platform_system == "Darwin" # compatibility fix for Torch
3+
joblib >= 1.0.0

requirements.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,2 @@
1-
# NOTE: once we add more dependencies, consider update dependabot to check for updates
2-
31
lightning-sdk >=0.2.9
42
lightning-utilities
5-
joblib

src/litmodels/integrations/mixins.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import inspect
22
import json
3-
import pickle
43
import tempfile
54
import warnings
65
from abc import ABC
@@ -10,6 +9,7 @@
109
from lightning_utilities.core.rank_zero import rank_zero_warn
1110

1211
from litmodels.io.cloud import download_model_files, upload_model_files
12+
from litmodels.io.utils import dump_pickle, load_pickle
1313

1414
if TYPE_CHECKING:
1515
import torch
@@ -89,8 +89,7 @@ def upload_model(
8989
"""
9090
name, model_name, temp_folder = self._setup(name, temp_folder)
9191
pickle_path = Path(temp_folder) / f"{model_name}.pkl"
92-
with open(pickle_path, "wb") as fp:
93-
pickle.dump(self, fp, protocol=pickle.HIGHEST_PROTOCOL)
92+
dump_pickle(model=self, path=pickle_path)
9493
if version:
9594
name = f"{name}:{version}"
9695
self._upload_model_files(name=name, path=pickle_path, metadata=metadata)
@@ -116,8 +115,7 @@ def download_model(
116115
if len(pkl_files) > 1:
117116
raise RuntimeError(f"Multiple pickle files found for model: {model_registry} with {pkl_files}")
118117
pkl_path = Path(temp_folder) / pkl_files[0]
119-
with open(pkl_path, "rb") as fp:
120-
obj = pickle.load(fp)
118+
obj = load_pickle(path=pkl_path)
121119
if not isinstance(obj, cls):
122120
raise RuntimeError(f"Unpickled object is not of type {cls.__name__}: {type(obj)}")
123121
return obj

src/litmodels/io/gateway.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
from pathlib import Path
44
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
55

6-
import joblib
76
from lightning_utilities import module_available
87

98
from litmodels.io.cloud import download_model_files, upload_model_files
9+
from litmodels.io.utils import dump_pickle, load_pickle
1010

1111
if module_available("torch"):
1212
import torch
@@ -56,7 +56,7 @@ def upload_model(
5656
torch.save(model.state_dict(), path)
5757
else:
5858
path = os.path.join(staging_dir, f"{model.__class__.__name__}.pkl")
59-
joblib.dump(model, path)
59+
dump_pickle(model=model, path=path)
6060

6161
return upload_model_files(
6262
path=path,
@@ -111,7 +111,7 @@ def load_model(name: str, download_dir: str = ".") -> Any:
111111
raise NotImplementedError("Downloaded model with multiple files is not supported yet.")
112112
model_path = Path(download_dir) / download_paths[0]
113113
if model_path.suffix.lower() == ".pkl":
114-
return joblib.load(model_path)
114+
return load_pickle(path=model_path)
115115
if model_path.suffix.lower() == ".ts":
116116
return torch.jit.load(model_path)
117117
raise NotImplementedError(f"Loading model from {model_path.suffix} is not supported yet.")

src/litmodels/io/utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import pickle
2+
from pathlib import Path
3+
from typing import Any, Union
4+
5+
from lightning_utilities import module_available
6+
7+
if module_available("joblib"):
8+
import joblib
9+
else:
10+
joblib = None
11+
12+
13+
def dump_pickle(model: Any, path: Union[str, Path]) -> None:
    """Serialize a model to a file on disk.

    When ``joblib`` is available the model is written with ``joblib.dump`` using
    compression level 7; otherwise it is written with the standard ``pickle``
    module at the highest available protocol.

    Args:
        model: The object to serialize.
        path: The destination file path.

    """
    if joblib is None:
        # Fallback when joblib is not installed: plain, uncompressed pickle.
        with open(path, "wb") as stream:
            pickle.dump(model, stream, protocol=pickle.HIGHEST_PROTOCOL)
        return
    joblib.dump(model, filename=path, compress=7)
25+
26+
27+
def load_pickle(path: Union[str, Path]) -> Any:
28+
"""Load a model from a pickle file.
29+
30+
Args:
31+
path: The path to the pickle file.
32+
33+
Returns:
34+
The unpickled model.
35+
"""
36+
if joblib is not None:
37+
return joblib.load(path)
38+
with open(path, "rb") as fp:
39+
return pickle.load(fp)

0 commit comments

Comments
 (0)