Skip to content

Commit f0d15ac

Browse files
committed
Added the CatBoost Model
1 parent 4a36d14 commit f0d15ac

File tree

7 files changed

+317
-11
lines changed

7 files changed

+317
-11
lines changed

conf/base/catalog.yml

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ train_test_split_visualization:
132132
# Trained xgboost model
133133
xgboost_model:
134134
type: pickle.PickleDataset
135-
filepath: data/05_model_output/xgboost_model.pkl
135+
filepath: data/05_model_output/xgboost/xgboost_model.pkl
136136
backend: pickle
137137
metadata:
138138
kedro-viz:
@@ -141,7 +141,7 @@ xgboost_model:
141141
# Trained Random Forest model
142142
random_forest_model:
143143
type: pickle.PickleDataset
144-
filepath: data/05_model_output/random_forest_model.pkl
144+
filepath: data/05_model_output/random_forest/random_forest_model.pkl
145145
backend: pickle
146146
metadata:
147147
kedro-viz:
@@ -150,7 +150,7 @@ random_forest_model:
150150
# Trained LightGBM model
151151
lightgbm_model:
152152
type: pickle.PickleDataset
153-
filepath: data/05_model_output/lightgbm_model.pkl
153+
filepath: data/05_model_output/lightgbm/lightgbm_model.pkl
154154
backend: pickle
155155
metadata:
156156
kedro-viz:
@@ -159,7 +159,7 @@ lightgbm_model:
159159
# xgboost feature importance plot
160160
xgboost_feature_importance_plot:
161161
type: matplotlib.MatplotlibWriter
162-
filepath: data/04_reporting/xgboost_feature_importance_plot.png
162+
filepath: data/04_reporting/xgboost/xgboost_feature_importance_plot.png
163163
save_args:
164164
format: png
165165
metadata:
@@ -169,7 +169,7 @@ xgboost_feature_importance_plot:
169169
# lightgbm feature importance plot
170170
lightgbm_feature_importance_plot:
171171
type: matplotlib.MatplotlibWriter
172-
filepath: data/04_reporting/lightgbm_feature_importance_plot.png
172+
filepath: data/04_reporting/lightgbm/lightgbm_feature_importance_plot.png
173173
save_args:
174174
format: png
175175
metadata:
@@ -179,7 +179,7 @@ lightgbm_feature_importance_plot:
179179
# Random Forest feature importance plot
180180
random_forest_feature_importance_plot:
181181
type: matplotlib.MatplotlibWriter
182-
filepath: data/04_reporting/random_forest_feature_importance_plot.png
182+
filepath: data/04_reporting/random_forest/random_forest_feature_importance_plot.png
183183
save_args:
184184
format: png
185185
metadata:
@@ -189,7 +189,7 @@ random_forest_feature_importance_plot:
189189
# real_data_and_xgboost_predictions_plot
190190
real_data_and_xgboost_predictions_plot:
191191
type: matplotlib.MatplotlibWriter
192-
filepath: data/04_reporting/real_data_and_xgboost_predictions_plot.png
192+
filepath: data/04_reporting/xgboost/real_data_and_xgboost_predictions_plot.png
193193
save_args:
194194
format: png
195195
metadata:
@@ -199,7 +199,7 @@ real_data_and_xgboost_predictions_plot:
199199
# real_data_and_lightgbm_predictions_plot
200200
real_data_and_lightgbm_predictions_plot:
201201
type: matplotlib.MatplotlibWriter
202-
filepath: data/04_reporting/real_data_and_lightgbm_predictions_plot.png
202+
filepath: data/04_reporting/lightgbm/real_data_and_lightgbm_predictions_plot.png
203203
save_args:
204204
format: png
205205
metadata:
@@ -209,7 +209,56 @@ real_data_and_lightgbm_predictions_plot:
209209
# real_data_and_rf_predictions_plot
210210
real_data_and_rf_predictions_plot:
211211
type: matplotlib.MatplotlibWriter
212-
filepath: data/04_reporting/real_data_and_rf_predictions_plot.png
212+
filepath: data/04_reporting/random_forest/real_data_and_rf_predictions_plot.png
213+
save_args:
214+
format: png
215+
metadata:
216+
kedro-viz:
217+
layer: reporting
218+
219+
# Trained CatBoost model
220+
catboost_model:
221+
type: pickle.PickleDataset
222+
filepath: data/05_model_output/catboost/catboost_model.pkl
223+
backend: pickle
224+
metadata:
225+
kedro-viz:
226+
layer: model
227+
228+
# CatBoost feature importance plot
229+
catboost_feature_importance_plot:
230+
type: matplotlib.MatplotlibWriter
231+
filepath: data/04_reporting/catboost/catboost_feature_importance_plot.png
232+
save_args:
233+
format: png
234+
metadata:
235+
kedro-viz:
236+
layer: reporting
237+
238+
# Real data and CatBoost predictions plot
239+
real_data_and_catboost_predictions_plot:
240+
type: matplotlib.MatplotlibWriter
241+
filepath: data/04_reporting/catboost/real_data_and_catboost_predictions_plot.png
242+
save_args:
243+
format: png
244+
metadata:
245+
kedro-viz:
246+
layer: reporting
247+
248+
# CatBoost SHAP summary plot for model explainability
249+
catboost_shap_summary_plot:
250+
type: matplotlib.MatplotlibWriter
251+
filepath: data/04_reporting/catboost/catboost_shap_summary_plot.png
252+
save_args:
253+
format: png
254+
metadata:
255+
kedro-viz:
256+
layer: reporting
257+
258+
# CatBoost Partial Dependence Plot
259+
catboost_partial_dependence_plot:
260+
type: matplotlib.MatplotlibWriter
261+
filepath: data/04_reporting/catboost/catboost_partial_dependence_plot.png
213262
save_args:
214263
format: png
215264
metadata:
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Parameters specific to the CatBoost model
2+
catboost_training_pipeline.catboost_model_params:
3+
target: total_consumption
4+
threshold: 2010-05-17
5+
iterations: 1000
6+
depth: 3
7+
learning_rate: 0.01
8+
loss_function: 'RMSE'
9+
eval_metric: 'RMSE'
10+
verbose_eval: 100
11+
random_state: 42
12+
data_types:
13+
boolean_columns:
14+
- 'is_holiday'
15+
- 'conditions_clear'
16+
- 'conditions_overcast'
17+
- 'conditions_partiallycloudy'
18+
- 'conditions_rain'
19+
- 'conditions_rainovercast'
20+
- 'conditions_rainpartiallycloudy'
21+
- 'conditions_snowovercast'
22+
- 'conditions_snowpartiallycloudy'
23+
- 'conditions_snowrain'
24+
- 'conditions_snowrainovercast'
25+
- 'conditions_snowrainpartiallycloudy'
26+
27+
# Features for Partial Dependence Plot (PDP)
28+
catboost_training_pipeline.catboost_pdp_features:
29+
- 'temp_lag_1'
30+
- 'dayofweek'
31+
- 'total_consumption_lag_1'
32+
- 'tempmax_lag_1'
33+
- 'feelslike_lag_1'
34+
- 'tempmin_lag_1'
35+
- 'tempmin_rolling_mean_3'
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""
2+
This is a boilerplate pipeline 'catboost_pipeline'
3+
generated using Kedro 0.19.10
4+
"""
5+
6+
from .pipeline import create_pipeline
7+
8+
__all__ = ["create_pipeline"]
9+
10+
__version__ = "0.1"
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import logging
2+
import matplotlib.pyplot as plt
3+
from catboost import CatBoostRegressor
4+
from shap import TreeExplainer, summary_plot
5+
from sklearn.inspection import partial_dependence, PartialDependenceDisplay
6+
7+
def train_catboost_model(X_train, y_train, params):
    """
    Trains a CatBoost regression model using the given training data and parameters.

    The caller's ``X_train`` is never modified: the boolean conversion is applied
    to a private copy of the frame.

    Args:
        X_train (DataFrame): Training features.
        y_train (Series/DataFrame): Training target; it is re-aligned to the rows
            of ``X_train`` that survive null-dropping via ``.loc`` on their index.
        params (dict): Dictionary containing CatBoost parameters. The keys
            ``iterations``, ``depth``, ``learning_rate``, ``loss_function`` and
            ``eval_metric`` are required. ``random_state`` (default 42),
            ``verbose_eval`` (default True) and ``data_types.boolean_columns``
            (default empty) are optional.

    Returns:
        CatBoostRegressor: Trained CatBoost model.

    Raises:
        KeyError: If one of the required keys is missing from ``params``.
    """
    # Initialize logger
    logger = logging.getLogger(__name__)
    logger.info("Starting CatBoost model training...")

    # Log model parameters for reproducibility/debugging
    logger.info(
        f"CatBoost parameters: iterations={params['iterations']}, "
        f"depth={params['depth']}, learning_rate={params['learning_rate']}, "
        f"loss_function={params['loss_function']}, eval_metric={params['eval_metric']}, "
        f"verbose_eval={params.get('verbose_eval', True)}"
    )

    # Work on a copy so the dtype conversion below does not leak into the
    # DataFrame object that the caller (and any other consumer) still holds.
    features = X_train.copy()

    # Convert specified columns to boolean based on provided data types.
    # NOTE(review): the cast runs before dropna(), so a missing value in one of
    # these columns is coerced to a boolean instead of causing its row to be
    # dropped -- confirm this is intended.
    boolean_columns = params.get("data_types", {}).get("boolean_columns", [])
    for col in boolean_columns:
        if col in features.columns:
            logger.debug(f"Converting column {col} to boolean.")
            features[col] = features[col].astype("bool")

    # Drop rows with null values from the features and align the target
    X_train_clean = features.dropna()
    y_train_clean = y_train.loc[X_train_clean.index]

    # Log how many rows were dropped
    dropped_rows = features.shape[0] - X_train_clean.shape[0]
    logger.info(f"Dropped {dropped_rows} rows with null values from training data.")

    # Instantiate CatBoostRegressor with the given parameters
    cat_model = CatBoostRegressor(
        allow_writing_files=False,
        iterations=params["iterations"],
        depth=params["depth"],
        learning_rate=params["learning_rate"],
        loss_function=params["loss_function"],
        eval_metric=params["eval_metric"],
        random_seed=params.get("random_state", 42),
        verbose=params.get("verbose_eval", True)
    )

    # Train the model; the evaluation set passed here is the training data
    # itself, so the reported metric is a training metric, not a hold-out one.
    logger.info("Training the CatBoost model...")
    cat_model.fit(
        X_train_clean,
        y_train_clean,
        eval_set=[(X_train_clean, y_train_clean)],
        verbose=params.get("verbose_eval", True)
    )

    # Log the completion of the training process
    logger.info("CatBoost model training completed successfully.")

    return cat_model
71+
72+
73+
def explain_catboost_model(model, X_train):
    """
    Builds a SHAP summary plot for a trained CatBoost model.

    Args:
        model: Trained CatBoost model handed to ``TreeExplainer``.
        X_train: Feature set for which SHAP values are computed and plotted.

    Returns:
        The matplotlib figure holding the SHAP summary plot. The figure is
        closed in pyplot before being returned.
    """
    log = logging.getLogger(__name__)
    log.info("Computing SHAP values for CatBoost model...")

    # SHAP values for every row of the supplied feature set
    shap_matrix = TreeExplainer(model).shap_values(X_train)

    # Open a fresh figure, let SHAP draw on it without displaying it,
    # then pick up whichever figure is current afterwards.
    plt.figure()
    summary_plot(shap_matrix, X_train, show=False)
    shap_figure = plt.gcf()
    plt.close(shap_figure)

    log.info("SHAP summary plot created successfully.")
    return shap_figure
99+
100+
101+
def plot_partial_dependence_catboost(model, X_train, features):
    """
    Draws partial dependence curves for the requested features of a trained CatBoost model.

    Args:
        model: Trained CatBoost model.
        X_train (DataFrame): Training features.
        features (list): Feature names or indices to compute partial dependence for.

    Returns:
        matplotlib.figure.Figure: The partial dependence plot figure. The figure
        is closed in pyplot before being returned.
    """
    log = logging.getLogger(__name__)
    log.info("Creating partial dependence plot for CatBoost model...")

    # A single 12x8 axes is handed to scikit-learn's PartialDependenceDisplay
    pdp_figure, pdp_axes = plt.subplots(figsize=(12, 8))
    PartialDependenceDisplay.from_estimator(
        model,
        X_train,
        features=features,
        ax=pdp_axes,
    )
    pdp_axes.set_title("Partial Dependence Plot")
    plt.tight_layout()
    plt.close(pdp_figure)

    log.info("Partial dependence plot created successfully.")
    return pdp_figure
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from kedro.pipeline import Pipeline, node, pipeline
2+
3+
from .nodes import (
4+
train_catboost_model,
5+
explain_catboost_model,
6+
plot_partial_dependence_catboost
7+
)
8+
9+
from ..random_forest_pipeline.nodes import (
10+
plot_real_data_and_predictions_with_train,
11+
plot_feature_importance,
12+
generate_predictions,
13+
)
14+
15+
def create_pipeline(**kwargs) -> Pipeline:
    """
    Assembles the CatBoost training pipeline.

    Six nodes are wired together: training, feature importance, prediction,
    real-vs-predicted plot, SHAP summary plot and partial dependence plot.
    The pipeline is wrapped in the ``catboost_training_pipeline`` namespace;
    the datasets listed in ``inputs`` and ``outputs`` below are passed to
    ``pipeline()`` as its explicit inputs and outputs.

    Args:
        **kwargs: Accepted but not used by this function.

    Returns:
        Pipeline: The namespaced CatBoost pipeline.
    """
    # Node 1: Train CatBoost Model
    training_node = node(
        func=train_catboost_model,
        inputs=["X_train", "y_train", "params:catboost_model_params"],
        outputs="catboost_model",
        name="train_catboost_model_node",
        tags=["model_training", "catboost"],
    )

    # Node 2: Plot Feature Importance
    importance_node = node(
        func=plot_feature_importance,
        inputs=["catboost_model", "X_train"],
        outputs="catboost_feature_importance_plot",
        name="plot_feature_importance_node",
        tags=["feature_importance", "visualization", "catboost", "model_training"],
    )

    # Node 3: Generate Predictions
    prediction_node = node(
        func=generate_predictions,
        inputs=["X_test", "catboost_model"],
        outputs="catboost_model_predictions",
        name="generate_predictions_node",
        tags=["predictions", "catboost", "model_training"],
    )

    # Node 4: Plot Real Data and Predictions
    comparison_plot_node = node(
        func=plot_real_data_and_predictions_with_train,
        inputs=["y_train", "y_test", "catboost_model_predictions"],
        outputs="real_data_and_catboost_predictions_plot",
        name="plot_real_data_and_predictions_node",
        tags=["data_visualization", "catboost", "model_training"],
    )

    # Node 5: Generate SHAP Summary Plot for Explainability
    shap_node = node(
        func=explain_catboost_model,
        inputs=["catboost_model", "X_train"],
        outputs="catboost_shap_summary_plot",
        name="explain_catboost_model_node",
        tags=["explainability", "catboost", "model_training"],
    )

    # Node 6: Plot Partial Dependence
    partial_dependence_node = node(
        func=plot_partial_dependence_catboost,
        inputs=["catboost_model", "X_train", "params:catboost_pdp_features"],
        outputs="catboost_partial_dependence_plot",
        name="plot_partial_dependence_node",
        tags=["partial_dependence", "catboost", "model_training"],
    )

    catboost_nodes = [
        training_node,
        importance_node,
        prediction_node,
        comparison_plot_node,
        shap_node,
        partial_dependence_node,
    ]

    return pipeline(
        catboost_nodes,
        tags="model_training",
        namespace="catboost_training_pipeline",
        inputs=["X_train", "y_train", "X_test", "y_test"],
        outputs=[
            "catboost_model",
            "catboost_feature_importance_plot",
            "real_data_and_catboost_predictions_plot",
            "catboost_shap_summary_plot",
            "catboost_partial_dependence_plot",
        ],
    )

0 commit comments

Comments
 (0)