|
4 | 4 | import logging |
5 | 5 | from datetime import datetime |
6 | 6 | from multiprocessing import Pool, Process, Queue |
| 7 | +from multiprocessing import cpu_count |
7 | 8 | from functools import partial |
8 | 9 | from queue import Empty as EmptyQueueException |
9 | 10 | import tornado.ioloop |
@@ -156,16 +157,12 @@ def train_individual_model(predictor_model, initial_run): |
def train_model(initial_run=False, data_queue=None):
    """Train the machine learning model(s).

    Retrains every model in the module-global ``PREDICTOR_MODEL_LIST`` using a
    process pool sized by the configured parallelism, capped at the machine's
    CPU count, then publishes the refreshed list on ``data_queue`` when one is
    provided.

    Args:
        initial_run: Forwarded to ``train_individual_model``; presumably True
            on the first training pass — confirm against that function.
        data_queue: Optional multiprocessing queue. When supplied, the
            retrained model list is put on it for a consumer process; when
            None (the default), the result is only stored in the global.
    """
    global PREDICTOR_MODEL_LIST
    # Clamp to at least one worker so a zero/negative configured value can
    # never reach Pool() (Pool raises ValueError for non-positive sizes),
    # and never request more workers than the machine has cores.
    parallelism = max(1, min(Configuration.parallelism, cpu_count()))
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    _LOGGER.info("Training models using ProcessPool of size: %s", parallelism)
    training_partial = partial(train_individual_model, initial_run=initial_run)
    with Pool(parallelism) as p:
        PREDICTOR_MODEL_LIST = p.map(training_partial, PREDICTOR_MODEL_LIST)
    # Guard: the original called data_queue.put() unconditionally, which
    # raised AttributeError whenever the default data_queue=None was used.
    if data_queue is not None:
        data_queue.put(PREDICTOR_MODEL_LIST)
170 | 167 |
|
171 | 168 |
|
|
0 commit comments