Skip to content

Commit 57efefd

Browse files
authored
Add/batch operation (#146)
* draft of `batch_export` function added to `Export` * draft of `batch_import` function added to `Import` * testcase for batch operation added * `CHANGELOG.md` updated * `autopep8.sh` run * fix index bug * apply `Codacy` feedback * feedback applied + `autopep8.sh` executed * lowercasing the first letter * `CHANGELOG.md` updated
1 parent ee1e49e commit 57efefd

File tree

4 files changed

+133
-2
lines changed

4 files changed

+133
-2
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
66

77
## [Unreleased]
88
### Added
9+
- batch operation testcases
10+
- `batch_export` function in `pymilo/pymilo_obj.py`
11+
- `batch_import` function in `pymilo/pymilo_obj.py`
912
- `CCA` model
1013
- `PLSCanonical` model
1114
- `PLSRegression` model

pymilo/pymilo_obj.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
# -*- coding: utf-8 -*-
22
"""PyMilo modules."""
3+
import os
4+
import re
35
import json
46
from copy import deepcopy
57
from warnings import warn
68
from traceback import format_exc
79
from .utils.util import get_sklearn_type, download_model
10+
from concurrent.futures import ThreadPoolExecutor, as_completed
811
from .pymilo_func import get_sklearn_data, get_sklearn_version, to_sklearn_model
912
from .exceptions.serialize_exception import PymiloSerializationException, SerializationErrorTypes
1013
from .exceptions.deserialize_exception import PymiloDeserializationException, DeserializationErrorTypes
11-
from .pymilo_param import PYMILO_VERSION, UNEQUAL_PYMILO_VERSIONS, UNEQUAL_SKLEARN_VERSIONS, INVALID_IMPORT_INIT_PARAMS
14+
from .pymilo_param import PYMILO_VERSION, UNEQUAL_PYMILO_VERSIONS, UNEQUAL_SKLEARN_VERSIONS
15+
from .pymilo_param import INVALID_IMPORT_INIT_PARAMS, BATCH_IMPORT_INVALID_DIRECTORY
1216

1317

1418
class Export:
@@ -73,6 +77,45 @@ def to_json(self):
7377
"model_type": self.type},
7478
})
7579

80+
@staticmethod
81+
def batch_export(models, file_addr, run_parallel=False):
82+
"""
83+
Export a batch of models to individual JSON files in a specified directory.
84+
85+
This method takes a list of trained models and exports each one into a JSON file. The models
86+
are exported concurrently using multiple threads, where each model is saved to a file named
87+
'model_{index}.json' in the provided directory.
88+
89+
:param models: list of models to get exported.
90+
:type models: list
91+
:param file_addr: the directory where exported JSON files will be saved.
92+
:type file_addr: str
93+
:param run_parallel: flag indicating the parallel execution of exports
94+
:type run_parallel: boolean
95+
:return: the count of models exported successfully
96+
"""
97+
if not os.path.exists(file_addr):
98+
os.mkdir(file_addr)
99+
100+
def export_model(model, index):
101+
try:
102+
Export(model).save(file_adr=os.path.join(file_addr, f"model_{index}.json"))
103+
return 1
104+
except Exception as _:
105+
return 0
106+
if run_parallel:
107+
with ThreadPoolExecutor() as executor:
108+
futures = [executor.submit(export_model, model, index) for index, model in enumerate(models)]
109+
count = 0
110+
for future in as_completed(futures):
111+
count += future.result()
112+
return count
113+
else:
114+
count = 0
115+
for index, model in enumerate(models):
116+
count += export_model(model, index)
117+
return count
118+
76119

77120
class Import:
78121
"""
@@ -138,3 +181,61 @@ def to_model(self):
138181
:return: sklearn model
139182
"""
140183
return to_sklearn_model(self)
184+
185+
@staticmethod
186+
def batch_import(file_addr, run_parallel=False):
187+
"""
188+
Import a batch of models from individual JSON files in a specified directory.
189+
190+
This method takes a directory containing JSON files and imports each one into a model.
191+
The models are imported concurrently using multiple threads, ensuring that the files are
192+
processed in the order determined by their numeric suffixes. The function returns the
193+
successfully imported models in the same order as their filenames.
194+
195+
:param file_addr: the directory where the JSON files to be imported are located.
196+
:type file_addr: str
197+
:param run_parallel: flag indicating the parallel execution of imports
198+
:type run_parallel: boolean
199+
:return: a tuple containing the count of models imported successfully and a list of the
200+
imported models in their filename order.
201+
"""
202+
if not os.path.exists(file_addr):
203+
raise FileNotFoundError(BATCH_IMPORT_INVALID_DIRECTORY)
204+
205+
json_files = [f for f in os.listdir(file_addr) if f.endswith('.json')]
206+
json_files.sort(key=lambda x: int(re.search(r'_(\d+)\.json$', x).group(1)))
207+
208+
models = [None] * len(json_files)
209+
count = 0
210+
211+
def import_model(file_path, index):
212+
try:
213+
model = Import(file_path).to_model()
214+
return index, model
215+
except Exception as _:
216+
return index, None
217+
218+
if run_parallel:
219+
with ThreadPoolExecutor() as executor:
220+
futures = {
221+
executor.submit(
222+
import_model,
223+
os.path.join(
224+
file_addr,
225+
file),
226+
index): index for index,
227+
file in enumerate(json_files)}
228+
for future in as_completed(futures):
229+
index, model = future.result()
230+
if model is not None:
231+
models[index] = model
232+
count += 1
233+
return count, models
234+
else:
235+
count = 0
236+
for index, file in enumerate(json_files):
237+
model = Import(os.path.join(file_addr, file)).to_model()
238+
if model is not None:
239+
models[index] = model
240+
count += 1
241+
return count, models

pymilo/pymilo_param.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@
9191
INVALID_IMPORT_INIT_PARAMS = "Invalid input parameters, you should either pass a valid file_adr or a json_dump or a url to initiate Import class."
9292
DOWNLOAD_MODEL_FAILED = "Failed to download the JSON file, Server didn't respond."
9393
INVALID_DOWNLOADED_MODEL = "The downloaded content is not a valid JSON file."
94-
94+
BATCH_IMPORT_INVALID_DIRECTORY = "The given directory does not exist."
9595

9696
SKLEARN_LINEAR_MODEL_TABLE = {
9797
"DummyRegressor": dummy.DummyRegressor,
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import os
2+
import re
3+
import random
4+
import numpy as np
5+
from pymilo import Export, Import
6+
from sklearn.metrics import mean_squared_error
7+
from sklearn.linear_model import LinearRegression
8+
from pymilo.utils.data_exporter import prepare_simple_regression_datasets
9+
10+
11+
def test_batch_execution():
12+
x_train, y_train, x_test, _ = prepare_simple_regression_datasets()
13+
linear_regression = LinearRegression()
14+
linear_regression.fit(x_train, y_train)
15+
pre_models = [linear_regression]*100
16+
exp_n = Export.batch_export(pre_models, os.getcwd())
17+
imp_n, post_models = Import.batch_import(os.getcwd())
18+
r_index = random.randint(0, len(post_models) - 1)
19+
pre_result = pre_models[r_index].predict(x_test)
20+
post_result = post_models[r_index].predict(x_test)
21+
mse = mean_squared_error(post_result, pre_result)
22+
pattern = re.compile(r'model_\d+\.json')
23+
for filename in os.listdir(os.getcwd()):
24+
if pattern.match(filename):
25+
file_path = os.path.join(os.getcwd(), filename)
26+
os.remove(file_path)
27+
return exp_n == imp_n and np.abs(mse) <= 10**(-8)

0 commit comments

Comments
 (0)