|
1 | 1 | # -*- coding: utf-8 -*-
|
2 | 2 | """PyMilo modules."""
|
| 3 | +import os |
| 4 | +import re |
3 | 5 | import json
|
4 | 6 | from copy import deepcopy
|
5 | 7 | from warnings import warn
|
6 | 8 | from traceback import format_exc
|
7 | 9 | from .utils.util import get_sklearn_type, download_model
|
| 10 | +from concurrent.futures import ThreadPoolExecutor, as_completed |
8 | 11 | from .pymilo_func import get_sklearn_data, get_sklearn_version, to_sklearn_model
|
9 | 12 | from .exceptions.serialize_exception import PymiloSerializationException, SerializationErrorTypes
|
10 | 13 | from .exceptions.deserialize_exception import PymiloDeserializationException, DeserializationErrorTypes
|
11 |
| -from .pymilo_param import PYMILO_VERSION, UNEQUAL_PYMILO_VERSIONS, UNEQUAL_SKLEARN_VERSIONS, INVALID_IMPORT_INIT_PARAMS |
| 14 | +from .pymilo_param import PYMILO_VERSION, UNEQUAL_PYMILO_VERSIONS, UNEQUAL_SKLEARN_VERSIONS |
| 15 | +from .pymilo_param import INVALID_IMPORT_INIT_PARAMS, BATCH_IMPORT_INVALID_DIRECTORY |
12 | 16 |
|
13 | 17 |
|
14 | 18 | class Export:
|
@@ -73,6 +77,45 @@ def to_json(self):
|
73 | 77 | "model_type": self.type},
|
74 | 78 | })
|
75 | 79 |
|
| 80 | + @staticmethod |
| 81 | + def batch_export(models, file_addr, run_parallel=False): |
| 82 | + """ |
| 83 | + Export a batch of models to individual JSON files in a specified directory. |
| 84 | +
|
| 85 | + This method takes a list of trained models and exports each one into a JSON file. The models |
| 86 | + are exported concurrently using multiple threads, where each model is saved to a file named |
| 87 | + 'model_{index}.json' in the provided directory. |
| 88 | +
|
| 89 | + :param models: list of models to get exported. |
| 90 | + :type models: list |
| 91 | + :param file_addr: the directory where exported JSON files will be saved. |
| 92 | + :type file_addr: str |
| 93 | + :param run_parallel: flag indicating the parallel execution of exports |
| 94 | + :type run_parallel: boolean |
| 95 | + :return: the count of models exported successfully |
| 96 | + """ |
| 97 | + if not os.path.exists(file_addr): |
| 98 | + os.mkdir(file_addr) |
| 99 | + |
| 100 | + def export_model(model, index): |
| 101 | + try: |
| 102 | + Export(model).save(file_adr=os.path.join(file_addr, f"model_{index}.json")) |
| 103 | + return 1 |
| 104 | + except Exception as _: |
| 105 | + return 0 |
| 106 | + if run_parallel: |
| 107 | + with ThreadPoolExecutor() as executor: |
| 108 | + futures = [executor.submit(export_model, model, index) for index, model in enumerate(models)] |
| 109 | + count = 0 |
| 110 | + for future in as_completed(futures): |
| 111 | + count += future.result() |
| 112 | + return count |
| 113 | + else: |
| 114 | + count = 0 |
| 115 | + for index, model in enumerate(models): |
| 116 | + count += export_model(model, index) |
| 117 | + return count |
| 118 | + |
76 | 119 |
|
77 | 120 | class Import:
|
78 | 121 | """
|
@@ -138,3 +181,61 @@ def to_model(self):
|
138 | 181 | :return: sklearn model
|
139 | 182 | """
|
140 | 183 | return to_sklearn_model(self)
|
| 184 | + |
| 185 | + @staticmethod |
| 186 | + def batch_import(file_addr, run_parallel=False): |
| 187 | + """ |
| 188 | + Import a batch of models from individual JSON files in a specified directory. |
| 189 | +
|
| 190 | + This method takes a directory containing JSON files and imports each one into a model. |
| 191 | + The models are imported concurrently using multiple threads, ensuring that the files are |
| 192 | + processed in the order determined by their numeric suffixes. The function returns the |
| 193 | + successfully imported models in the same order as their filenames. |
| 194 | +
|
| 195 | + :param file_addr: the directory where the JSON files to be imported are located. |
| 196 | + :type file_addr: str |
| 197 | + :param run_parallel: flag indicating the parallel execution of imports |
| 198 | + :type run_parallel: boolean |
| 199 | + :return: a tuple containing the count of models imported successfully and a list of the |
| 200 | + imported models in their filename order. |
| 201 | + """ |
| 202 | + if not os.path.exists(file_addr): |
| 203 | + raise FileNotFoundError(BATCH_IMPORT_INVALID_DIRECTORY) |
| 204 | + |
| 205 | + json_files = [f for f in os.listdir(file_addr) if f.endswith('.json')] |
| 206 | + json_files.sort(key=lambda x: int(re.search(r'_(\d+)\.json$', x).group(1))) |
| 207 | + |
| 208 | + models = [None] * len(json_files) |
| 209 | + count = 0 |
| 210 | + |
| 211 | + def import_model(file_path, index): |
| 212 | + try: |
| 213 | + model = Import(file_path).to_model() |
| 214 | + return index, model |
| 215 | + except Exception as _: |
| 216 | + return index, None |
| 217 | + |
| 218 | + if run_parallel: |
| 219 | + with ThreadPoolExecutor() as executor: |
| 220 | + futures = { |
| 221 | + executor.submit( |
| 222 | + import_model, |
| 223 | + os.path.join( |
| 224 | + file_addr, |
| 225 | + file), |
| 226 | + index): index for index, |
| 227 | + file in enumerate(json_files)} |
| 228 | + for future in as_completed(futures): |
| 229 | + index, model = future.result() |
| 230 | + if model is not None: |
| 231 | + models[index] = model |
| 232 | + count += 1 |
| 233 | + return count, models |
| 234 | + else: |
| 235 | + count = 0 |
| 236 | + for index, file in enumerate(json_files): |
| 237 | + model = Import(os.path.join(file_addr, file)).to_model() |
| 238 | + if model is not None: |
| 239 | + models[index] = model |
| 240 | + count += 1 |
| 241 | + return count, models |
0 commit comments