File tree Expand file tree Collapse file tree 2 files changed +20
-4
lines changed
Expand file tree Collapse file tree 2 files changed +20
-4
lines changed Original file line number Diff line number Diff line change 77Model can also be piped:
88 # cat <path_to_file> | m2cgen --language java
99"""
10- import pickle
1110import sys
1211from argparse import ArgumentParser , FileType
1312from inspect import signature
9998 "--version" , "-v" ,
10099 action = "version" ,
101100 version = f"%(prog)s { m2cgen .__version__ } " )
101+ parser .add_argument (
102+ "--pickle-lib" , "-pl" ,
103+ type = str ,
104+ dest = "lib" ,
105+ help = "Sets the lib used to save the model" ,
106+ choices = ["pickle" , "joblib" ],
107+ default = "pickle" )
102108
103109
104110def parse_args (args ):
@@ -109,7 +115,8 @@ def generate_code(args):
109115 sys .setrecursionlimit (args .recursion_limit )
110116
111117 with args .infile as f :
112- model = pickle .load (f )
118+ pickle_lib = __import__ (args .lib )
119+ model = pickle_lib .load (f )
113120
114121 exporter , supported_args = LANGUAGE_TO_EXPORTER [args .language ]
115122
Original file line number Diff line number Diff line change @@ -17,7 +17,8 @@ def _get_mock_args(
1717 package_name = None ,
1818 class_name = None ,
1919 infile = None ,
20- language = None
20+ language = None ,
21+ lib = "pickle"
2122):
2223 return mock .MagicMock (
2324 indent = indent ,
@@ -28,7 +29,8 @@ def _get_mock_args(
2829 class_name = class_name ,
2930 infile = infile ,
3031 language = language ,
31- recursion_limit = cli .MAX_RECURSION_DEPTH )
32+ recursion_limit = cli .MAX_RECURSION_DEPTH ,
33+ lib = lib )
3234
3335
3436def test_file_as_input (tmp_path ):
@@ -122,6 +124,13 @@ def test_namespace(pickled_model):
122124 assert "namespace Tests.ML {" in generated_code
123125
124126
127+ def test_joblib_loading (pickled_model ):
128+ mock_args = _get_mock_args (infile = pickled_model , language = "go" , lib = "joblib" )
129+ generated_code = cli .generate_code (mock_args ).strip ()
130+
131+ assert generated_code .startswith ("func score(input []float64) float64 {\n " )
132+
133+
125134def test_indent (pickled_model ):
126135 mock_args = _get_mock_args (infile = pickled_model , indent = 0 , language = "c_sharp" )
127136 generated_code = cli .generate_code (mock_args ).strip ()
You can’t perform that action at this time.
0 commit comments