|
3 | 3 | import os |
4 | 4 | import logging |
5 | 5 | from datetime import datetime |
6 | | -from multiprocessing import Process, Queue |
| 6 | +from multiprocessing import Pool, Process, Queue |
| 7 | +from functools import partial |
7 | 8 | from queue import Empty as EmptyQueueException |
8 | 9 | import tornado.ioloop |
9 | 10 | import tornado.web |
@@ -117,37 +118,54 @@ def make_app(data_queue): |
117 | 118 | ] |
118 | 119 | ) |
119 | 120 |
|
def train_individual_model(predictor_model, initial_run):
    """Download fresh data for one model's metric and retrain that model.

    Args:
        predictor_model: Object exposing a ``metric`` attribute (with
            ``metric_name`` and ``label_config``) and a
            ``train(data, interval)`` method.
        initial_run: When true, data is fetched starting
            ``Configuration.rolling_training_window_size`` ago; otherwise
            starting ``Configuration.metric_chunk_size`` ago.

    Returns:
        The same ``predictor_model`` object after ``train`` has been called.
    """
    metric = predictor_model.metric

    # A client is built on every call rather than shared at module level
    # (presumably so this function can run inside a worker process — confirm).
    prom = PrometheusConnect(
        url=Configuration.prometheus_url,
        headers=Configuration.prom_connect_headers,
        disable_ssl=True,
    )

    # Default to the short chunk; widen to the full rolling window on the
    # first run.
    fetch_from = datetime.now() - Configuration.metric_chunk_size
    if initial_run:
        fetch_from = datetime.now() - Configuration.rolling_training_window_size

    # Download new metric data from prometheus
    fetched_series = prom.get_metric_range_data(
        metric_name=metric.metric_name,
        label_config=metric.label_config,
        start_time=fetch_from,
        end_time=datetime.now(),
    )
    # Only the first returned entry is used.
    latest_data = fetched_series[0]

    # Train the new model, timing how long it takes
    training_began = datetime.now()
    predictor_model.train(latest_data, Configuration.retraining_interval_minutes)

    _LOGGER.info(
        "Total Training time taken = %s, for metric: %s %s",
        str(datetime.now() - training_began),
        metric.metric_name,
        metric.label_config,
    )
    return predictor_model
| 155 | + |
def train_model(initial_run=False, data_queue=None):
    """Train every model in ``PREDICTOR_MODEL_LIST`` and publish the result.

    Args:
        initial_run: Forwarded to ``train_individual_model`` for each model.
        data_queue: Optional queue-like object; when provided, the trained
            model list is pushed onto it with ``put``. When ``None`` (the
            default) nothing is published.
    """
    global PREDICTOR_MODEL_LIST
    if Configuration.parallelism_required:
        _LOGGER.info("Training models concurrently using ProcessPool")
        training_partial = partial(train_individual_model, initial_run=initial_run)
        with Pool() as p:
            result = p.map(training_partial, PREDICTOR_MODEL_LIST)
        # Worker processes train copies of the models, so the module-level
        # list is rebound to the objects returned from the pool.
        PREDICTOR_MODEL_LIST = result
    else:
        _LOGGER.info("Training models sequentially")
        for predictor_model in PREDICTOR_MODEL_LIST:
            # The return value is the same object that was passed in, so it
            # does not need to be stored.
            train_individual_model(predictor_model, initial_run)

    # data_queue defaults to None; guard so a call without a queue does not
    # fail with AttributeError after all training work has completed.
    if data_queue is not None:
        data_queue.put(PREDICTOR_MODEL_LIST)
152 | 170 |
|
153 | 171 |
|
|
0 commit comments