
Commit ef3f9ee

adding a new env variable that provides an option to choose between sequential and parallel model training
1 parent 20dc284 commit ef3f9ee

2 files changed: +52 −29 lines

app.py

Lines changed: 47 additions & 29 deletions
@@ -3,7 +3,8 @@
 import os
 import logging
 from datetime import datetime
-from multiprocessing import Process, Queue
+from multiprocessing import Pool, Process, Queue
+from functools import partial
 from queue import Empty as EmptyQueueException
 import tornado.ioloop
 import tornado.web
@@ -117,37 +118,54 @@ def make_app(data_queue):
         ]
     )
 
+def train_individual_model(predictor_model, initial_run):
+    metric_to_predict = predictor_model.metric
+    pc = PrometheusConnect(
+        url=Configuration.prometheus_url,
+        headers=Configuration.prom_connect_headers,
+        disable_ssl=True,
+    )
 
-def train_model(initial_run=False, data_queue=None):
-    """Train the machine learning model."""
-    for predictor_model in PREDICTOR_MODEL_LIST:
-        metric_to_predict = predictor_model.metric
-        data_start_time = datetime.now() - Configuration.metric_chunk_size
-        if initial_run:
-            data_start_time = (
-                datetime.now() - Configuration.rolling_training_window_size
-            )
-
-        # Download new metric data from prometheus
-        new_metric_data = pc.get_metric_range_data(
-            metric_name=metric_to_predict.metric_name,
-            label_config=metric_to_predict.label_config,
-            start_time=data_start_time,
-            end_time=datetime.now(),
-        )[0]
-
-        # Train the new model
-        start_time = datetime.now()
-        predictor_model.train(
-            new_metric_data, Configuration.retraining_interval_minutes
-        )
-        _LOGGER.info(
-            "Total Training time taken = %s, for metric: %s %s",
-            str(datetime.now() - start_time),
-            metric_to_predict.metric_name,
-            metric_to_predict.label_config,
+    data_start_time = datetime.now() - Configuration.metric_chunk_size
+    if initial_run:
+        data_start_time = (
+            datetime.now() - Configuration.rolling_training_window_size
         )
 
+    # Download new metric data from prometheus
+    new_metric_data = pc.get_metric_range_data(
+        metric_name=metric_to_predict.metric_name,
+        label_config=metric_to_predict.label_config,
+        start_time=data_start_time,
+        end_time=datetime.now(),
+    )[0]
+
+    # Train the new model
+    start_time = datetime.now()
+    predictor_model.train(
+        new_metric_data, Configuration.retraining_interval_minutes)
+
+    _LOGGER.info(
+        "Total Training time taken = %s, for metric: %s %s",
+        str(datetime.now() - start_time),
+        metric_to_predict.metric_name,
+        metric_to_predict.label_config,
+    )
+    return predictor_model
+
+def train_model(initial_run=False, data_queue=None):
+    """Train the machine learning model."""
+    global PREDICTOR_MODEL_LIST
+    if Configuration.parallelism_required:
+        _LOGGER.info("Training models concurrently using ProcessPool")
+        training_partial = partial(train_individual_model, initial_run=initial_run)
+        with Pool() as p:
+            result = p.map(training_partial, PREDICTOR_MODEL_LIST)
+        PREDICTOR_MODEL_LIST = result
+    else:
+        _LOGGER.info("Training models sequentially")
+        for predictor_model in PREDICTOR_MODEL_LIST:
+            model = train_individual_model(predictor_model, initial_run)
     data_queue.put(PREDICTOR_MODEL_LIST)
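The parallel branch added above relies on multiprocessing.Pool plus functools.partial: partial binds the initial_run flag so Pool.map only has to pass each predictor model as the single mapped argument, and the trained copies returned by the worker processes replace PREDICTOR_MODEL_LIST. Below is a minimal, self-contained sketch of that same pattern; DummyModel and train_one are hypothetical stand-ins for the project's predictor models and train_individual_model, not code from this commit.

from functools import partial
from multiprocessing import Pool


class DummyModel:
    """Hypothetical stand-in for a predictor model object."""

    def __init__(self, name):
        self.name = name
        self.trained = False


def train_one(model, initial_run):
    # Runs in a worker process: each worker receives a pickled copy of the
    # model, mutates it, and returns it to the parent process.
    model.trained = True
    return model


if __name__ == "__main__":
    models = [DummyModel("m%d" % i) for i in range(4)]
    train_partial = partial(train_one, initial_run=True)
    with Pool() as p:
        # map() preserves order, so the returned list can simply replace the
        # original one, mirroring PREDICTOR_MODEL_LIST = result above.
        models = p.map(train_partial, models)
    print([m.trained for m in models])  # [True, True, True, True]
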
configuration.py

Lines changed: 5 additions & 0 deletions
@@ -57,3 +57,8 @@ class Configuration:
         "Metric data rolling training window size: %s", rolling_training_window_size
     )
     _LOGGER.info("Model retraining interval: %s minutes", retraining_interval_minutes)
+
+    # An option for parallelism.
+    # Setting FLT_PARALLELISM to True will enable the usage of a process pool
+    # during training.
+    parallelism_required = bool(os.getenv("FLT_PARALLELISM", ""))
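Worth noting about the check itself: bool() on a string only tests for emptiness, so any non-empty value of FLT_PARALLELISM, including "False" or "0", enables the process pool; only leaving the variable unset (or set to an empty string) keeps training sequential. For example, launching the service with FLT_PARALLELISM=true in the environment turns the parallel path on. The snippet below is plain Python behavior using the same expression the commit adds, not project code.

import os

os.environ["FLT_PARALLELISM"] = "False"
print(bool(os.getenv("FLT_PARALLELISM", "")))   # True  -> parallel training

os.environ["FLT_PARALLELISM"] = ""
print(bool(os.getenv("FLT_PARALLELISM", "")))   # False -> sequential training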
