|
3 | 3 | import os |
4 | 4 | import logging |
5 | 5 | from datetime import datetime |
6 | | -from multiprocessing import Process, Queue |
| 6 | +from multiprocessing import Pool, Process, Queue |
| 7 | +from multiprocessing import cpu_count |
| 8 | +from functools import partial |
7 | 9 | from queue import Empty as EmptyQueueException |
8 | 10 | import tornado.ioloop |
9 | 11 | import tornado.web |
@@ -117,37 +119,50 @@ def make_app(data_queue): |
117 | 119 | ] |
118 | 120 | ) |
119 | 121 |
|
| 122 | +def train_individual_model(predictor_model, initial_run): |
| 123 | + metric_to_predict = predictor_model.metric |
| 124 | + pc = PrometheusConnect( |
| 125 | + url=Configuration.prometheus_url, |
| 126 | + headers=Configuration.prom_connect_headers, |
| 127 | + disable_ssl=True, |
| 128 | + ) |
120 | 129 |
|
121 | | -def train_model(initial_run=False, data_queue=None): |
122 | | - """Train the machine learning model.""" |
123 | | - for predictor_model in PREDICTOR_MODEL_LIST: |
124 | | - metric_to_predict = predictor_model.metric |
125 | | - data_start_time = datetime.now() - Configuration.metric_chunk_size |
126 | | - if initial_run: |
127 | | - data_start_time = ( |
128 | | - datetime.now() - Configuration.rolling_training_window_size |
129 | | - ) |
130 | | - |
131 | | - # Download new metric data from prometheus |
132 | | - new_metric_data = pc.get_metric_range_data( |
133 | | - metric_name=metric_to_predict.metric_name, |
134 | | - label_config=metric_to_predict.label_config, |
135 | | - start_time=data_start_time, |
136 | | - end_time=datetime.now(), |
137 | | - )[0] |
138 | | - |
139 | | - # Train the new model |
140 | | - start_time = datetime.now() |
141 | | - predictor_model.train( |
142 | | - new_metric_data, Configuration.retraining_interval_minutes |
143 | | - ) |
144 | | - _LOGGER.info( |
145 | | - "Total Training time taken = %s, for metric: %s %s", |
146 | | - str(datetime.now() - start_time), |
147 | | - metric_to_predict.metric_name, |
148 | | - metric_to_predict.label_config, |
| 130 | + data_start_time = datetime.now() - Configuration.metric_chunk_size |
| 131 | + if initial_run: |
| 132 | + data_start_time = ( |
| 133 | + datetime.now() - Configuration.rolling_training_window_size |
149 | 134 | ) |
150 | 135 |
|
| 136 | + # Download new metric data from prometheus |
| 137 | + new_metric_data = pc.get_metric_range_data( |
| 138 | + metric_name=metric_to_predict.metric_name, |
| 139 | + label_config=metric_to_predict.label_config, |
| 140 | + start_time=data_start_time, |
| 141 | + end_time=datetime.now(), |
| 142 | + )[0] |
| 143 | + |
| 144 | + # Train the new model |
| 145 | + start_time = datetime.now() |
| 146 | + predictor_model.train( |
| 147 | + new_metric_data, Configuration.retraining_interval_minutes) |
| 148 | + |
| 149 | + _LOGGER.info( |
| 150 | + "Total Training time taken = %s, for metric: %s %s", |
| 151 | + str(datetime.now() - start_time), |
| 152 | + metric_to_predict.metric_name, |
| 153 | + metric_to_predict.label_config, |
| 154 | + ) |
| 155 | + return predictor_model |
| 156 | + |
| 157 | +def train_model(initial_run=False, data_queue=None): |
| 158 | + """Train the machine learning model.""" |
| 159 | + global PREDICTOR_MODEL_LIST |
| 160 | + parallelism = min(Configuration.parallelism, cpu_count()) |
| 161 | + _LOGGER.info(f"Training models using ProcessPool of size:{parallelism}") |
| 162 | + training_partial = partial(train_individual_model, initial_run=initial_run) |
| 163 | + with Pool(parallelism) as p: |
| 164 | + result = p.map(training_partial, PREDICTOR_MODEL_LIST) |
| 165 | + PREDICTOR_MODEL_LIST = result |
151 | 166 | data_queue.put(PREDICTOR_MODEL_LIST) |
152 | 167 |
|
153 | 168 |
|
|
0 commit comments