Skip to content

Commit b0f76ef

Browse files
authored
fix joblib model loading (#542)
1 parent 3d26aa7 commit b0f76ef

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

m2cgen/cli.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
Model can also be piped:
88
# cat <path_to_file> | m2cgen --language java
99
"""
10-
import pickle
1110
import sys
1211
from argparse import ArgumentParser, FileType
1312
from inspect import signature
@@ -99,6 +98,13 @@
9998
"--version", "-v",
10099
action="version",
101100
version=f"%(prog)s {m2cgen.__version__}")
101+
parser.add_argument(
102+
"--pickle-lib", "-pl",
103+
type=str,
104+
dest="lib",
105+
help="Sets the lib used to save the model",
106+
choices=["pickle", "joblib"],
107+
default="pickle")
102108

103109

104110
def parse_args(args):
@@ -109,7 +115,8 @@ def generate_code(args):
109115
sys.setrecursionlimit(args.recursion_limit)
110116

111117
with args.infile as f:
112-
model = pickle.load(f)
118+
pickle_lib = __import__(args.lib)
119+
model = pickle_lib.load(f)
113120

114121
exporter, supported_args = LANGUAGE_TO_EXPORTER[args.language]
115122

tests/test_cli.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ def _get_mock_args(
1717
package_name=None,
1818
class_name=None,
1919
infile=None,
20-
language=None
20+
language=None,
21+
lib="pickle"
2122
):
2223
return mock.MagicMock(
2324
indent=indent,
@@ -28,7 +29,8 @@ def _get_mock_args(
2829
class_name=class_name,
2930
infile=infile,
3031
language=language,
31-
recursion_limit=cli.MAX_RECURSION_DEPTH)
32+
recursion_limit=cli.MAX_RECURSION_DEPTH,
33+
lib=lib)
3234

3335

3436
def test_file_as_input(tmp_path):
@@ -122,6 +124,13 @@ def test_namespace(pickled_model):
122124
assert "namespace Tests.ML {" in generated_code
123125

124126

127+
def test_joblib_loading(pickled_model):
128+
mock_args = _get_mock_args(infile=pickled_model, language="go", lib="joblib")
129+
generated_code = cli.generate_code(mock_args).strip()
130+
131+
assert generated_code.startswith("func score(input []float64) float64 {\n")
132+
133+
125134
def test_indent(pickled_model):
126135
mock_args = _get_mock_args(infile=pickled_model, indent=0, language="c_sharp")
127136
generated_code = cli.generate_code(mock_args).strip()

0 commit comments

Comments
 (0)