-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
56 lines (44 loc) · 1.37 KB
/
train.py
File metadata and controls
56 lines (44 loc) · 1.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""
Main module to train, store, and validate the classifier.
"""
import hydra
import logging
from training.modelling.data import data_cleaning_stage
from training.train_model import get_train_test_data, train_save_model
from training.val_model import val_model
from omegaconf import DictConfig
_steps = [
"data_cleaning",
"train_model",
"check_score"
]
def _parse_steps(steps_par):
    """
    Turn the ``main.steps`` config value into a list of step names.

    Parameters
    ----------
    steps_par : str or iterable of str
        ``"all"`` (every step in ``_steps``), a comma-separated string such
        as ``"data_cleaning, train_model"``, or an iterable of step names
        (e.g. a list given in the config file or on the command line).

    Returns
    -------
    list of str
        Requested step names with surrounding whitespace stripped and empty
        entries dropped. Names not in ``_steps`` are kept but logged as a
        warning, since they will not match any stage.
    """
    if isinstance(steps_par, str):
        if steps_par.strip() == "all":
            return list(_steps)
        requested = steps_par.split(",")
    else:
        requested = steps_par
    # Strip whitespace so "a, b" selects "b" instead of silently skipping " b".
    active = [str(step).strip() for step in requested if str(step).strip()]
    unknown = [step for step in active if step not in _steps]
    if unknown:
        logging.warning(
            "Ignoring unknown pipeline step(s): %s (valid steps: %s)",
            ", ".join(unknown),
            ", ".join(_steps),
        )
    return active


@hydra.main(config_name="config.yml")
def go(config: DictConfig):
    """
    Run the selected pipeline stages.

    Stages named in ``config['main']['steps']`` run in a fixed order:
    data cleaning, then model training, then score checking.

    Parameters
    ----------
    config : DictConfig
        Hydra configuration. Reads ``main.steps`` (``"all"``, a
        comma-separated string of step names, or a list of step names)
        and ``data.cat_features``.
    """
    logging.basicConfig(level=logging.INFO)
    # Hydra may run the job in its own output directory; resolve project
    # paths against the directory the script was launched from.
    root_path = hydra.utils.get_original_cwd()

    # Steps to execute
    active_steps = _parse_steps(config['main']['steps'])
    cat_features = config['data']['cat_features']

    if "data_cleaning" in active_steps:
        logging.info("Cleaning and saving raw_data")
        data_cleaning_stage(root_path)

    # Loaded after the optional cleaning stage (presumably so the split
    # reflects freshly cleaned data) and before the stages that consume it.
    train_df, test_df = get_train_test_data(root_path)

    if "train_model" in active_steps:
        logging.info("Train/Test model procedure started")
        train_save_model(train_df, cat_features, root_path)

    if "check_score" in active_steps:
        logging.info("Score check procedure started")
        val_model(test_df, cat_features, root_path)
if __name__ == "__main__":
    # Main entrypoint: the @hydra.main decorator on go() builds the config
    # from the config file and command-line overrides, so no args are passed.
    go()