|
11 | 11 | from imblearn.over_sampling import BorderlineSMOTE |
12 | 12 | from xgboost import XGBClassifier |
13 | 13 |
|
14 | | -# ✅ Load Dataset |
15 | 14 | CSV_PATH = "/home/pavithra/k8s-failure-prediction/data/merged_data.csv" |
16 | 15 | df = pd.read_csv(CSV_PATH) |
17 | 16 |
|
18 | | -# ✅ Preprocessing |
19 | 17 | df.columns = df.columns.str.strip().str.replace(r'\s+', '_', regex=True).str.lower() |
20 | 18 | df["timestamp"] = pd.to_datetime(df["timestamp"]) |
21 | 19 | df.set_index("timestamp", inplace=True) |
22 | 20 |
|
23 | | -# ✅ Feature Engineering |
24 | 21 | for col in df.columns: |
25 | 22 | df[f"{col}_avg"] = df[col].rolling(window=5, min_periods=1).mean() |
26 | | - |
27 | | -# ✅ Target Variable |
28 | 23 | df["target"] = (df["container_restart_count"].diff().fillna(0) > 1).astype(int) |
29 | 24 | df.drop(columns=["container_restart_count"], inplace=True) |
30 | 25 |
|
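A quick check of the target definition: `diff()` on the restart counter flags a row only when the count jumps by more than one between consecutive samples, so single restarts are never labeled as failures. A minimal sketch on toy data (not from the dataset):

```python
import pandas as pd

restarts = pd.Series([0, 0, 1, 1, 4, 4])              # toy restart counter
target = (restarts.diff().fillna(0) > 1).astype(int)
print(target.tolist())  # [0, 0, 0, 0, 1, 0] -- only the jump of 3 is flagged
```

Note that the rolling loop above already produced `container_restart_count_avg` before the raw column was dropped, so a smoothed copy of the target's source column remains in the feature set; it may be worth dropping as well.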
31 | | -# ✅ Prepare Data |
32 | 26 | X = df.drop(columns=["target"]) |
33 | 27 | y = df["target"] |
34 | | - |
35 | | -# ✅ Handle Class Imbalance |
| 28 | +# Handle class imbalance: oversample with BorderlineSMOTE when the minority class is large enough |
36 | 29 | if y.value_counts().min() >= 5: |
37 | 30 | smote = BorderlineSMOTE(sampling_strategy='auto', random_state=42) |
38 | 31 | X_resampled, y_resampled = smote.fit_resample(X, y) |
39 | 32 | else: |
40 | 33 | X_resampled, y_resampled = X, y |
41 | 34 |
|
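One hedge on the `>= 5` guard: if I read imblearn correctly, BorderlineSMOTE's default `k_neighbors=5` needs at least six minority samples (the neighbor search asks for `k_neighbors + 1` points), so a minority count of exactly 5 can still raise. Inspecting the balance directly makes the guard easier to reason about; a sketch:

```python
from collections import Counter

print("before:", Counter(y))
if y.value_counts().min() > 5:  # strictly more than k_neighbors
    X_resampled, y_resampled = BorderlineSMOTE(random_state=42).fit_resample(X, y)
    print("after:", Counter(y_resampled))  # classes roughly balanced
```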
42 | | -# ✅ Train-Test Split |
| 35 | +# Train-test split |
43 | 36 | X_train, X_test, y_train, y_test = train_test_split(X_resampled, y_resampled, test_size=0.2, random_state=42) |
44 | 37 |
|
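A caveat on the ordering here: resampling before the split lets synthetic minority points (interpolations of real ones) land in the test set, which inflates test accuracy. A sketch of the leakage-free order, splitting first and oversampling only the training fold (same variable names, just reordered):

```python
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)
if y_train.value_counts().min() > 5:
    X_train, y_train = BorderlineSMOTE(random_state=42).fit_resample(X_train, y_train)
# X_test / y_test keep the original, untouched class distribution
```

Since the frame is indexed by timestamp, a time-aware scheme such as sklearn's `TimeSeriesSplit` would be stricter still; a shuffled split lets the model train on the future and test on the past.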
45 | | -# ✅ Reduce Overfitting (Final Fix) |
| 38 | +# Constrain the forest to reduce overfitting |
46 | 39 | rf = RandomForestClassifier( |
47 | | - n_estimators=300, # More trees |
48 | | - max_depth=10, # Reduce tree depth |
49 | | - min_samples_split=20, # More samples needed per split |
50 | | - min_samples_leaf=10, # Prevent small branches |
| 40 | +    n_estimators=300,        # more trees lower variance |
| 41 | +    max_depth=10,            # cap tree depth |
| 42 | +    min_samples_split=20,    # require more samples per split |
| 43 | +    min_samples_leaf=10,     # avoid tiny leaves |
51 | 44 | bootstrap=True, |
52 | 45 | random_state=42 |
53 | 46 | ) |
54 | 47 |
|
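Whether these constraints actually close the train/test gap is easier to judge with cross-validation than with a single split. A minimal sketch using `cross_val_score` (sklearn is already a dependency of this script):

```python
from sklearn.model_selection import cross_val_score

cv_scores = cross_val_score(rf, X_train, y_train, cv=5, scoring="f1")
print(f"CV F1: {cv_scores.mean():.3f} +/- {cv_scores.std():.3f}")
# A large gap between this and rf.score(X_train, y_train) still signals overfitting.
```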
55 | | -# ✅ Ensemble Model (Random Forest + XGBoost) |
56 | 48 | xgb = XGBClassifier(n_estimators=200, learning_rate=0.05, max_depth=7, subsample=0.8, colsample_bytree=0.8, random_state=42) |
57 | 49 | rf.fit(X_train, y_train) |
58 | 50 | xgb.fit(X_train, y_train) |
59 | 51 |
|
60 | | -# ✅ Predictions |
| 52 | +# Predict on the held-out set |
61 | 53 | y_pred_rf = rf.predict(X_test) |
62 | 54 | y_pred_xgb = xgb.predict(X_test) |
63 | 55 |
|
64 | | -# ✅ Combine Predictions (Soft Voting) |
| 56 | +# Combine predictions (integer averaging: failure only when both models agree) |
65 | 57 | y_pred_ensemble = (y_pred_rf + y_pred_xgb) // 2 |
66 | | - |
67 | | -# ✅ Evaluate Model |
68 | 58 | train_acc = rf.score(X_train, y_train) * 100 |
69 | 59 | test_acc = accuracy_score(y_test, y_pred_ensemble) * 100 |
70 | 60 | print(f"\n🎯 Train Accuracy: {train_acc:.2f} %") |
71 | 61 | print(f"🎯 Test Accuracy: {test_acc:.2f} %") |
72 | 62 | print("\n🔹 Classification Report:\n", classification_report(y_test, y_pred_ensemble)) |
73 | 63 |
|
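The integer average above is hard consensus, not soft voting: `(y_pred_rf + y_pred_xgb) // 2` yields 1 only when both models predict 1. If soft voting (as the removed comment said) is the intent, averaging predicted probabilities keeps more signal; a sketch:

```python
# Soft voting: average the class-1 probabilities, then threshold
proba = (rf.predict_proba(X_test)[:, 1] + xgb.predict_proba(X_test)[:, 1]) / 2
y_pred_soft = (proba >= 0.5).astype(int)
```

sklearn's `VotingClassifier(estimators=[("rf", rf), ("xgb", xgb)], voting="soft")` wraps the same idea in a single fit/predict interface.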
74 | | -# ✅ Save Model |
| 64 | +# Save the model, then reload it to verify |
75 | 65 | MODEL_PATH = "../models/k8s_failure_model.pkl" |
76 | 66 | joblib.dump(rf, MODEL_PATH) |
77 | 67 | model = joblib.load(MODEL_PATH) |
78 | 68 | print("The model was trained on these features:\n") |
79 | 69 | print(model.feature_names_in_) |
80 | 70 | print(f"\n✅ Model saved at {MODEL_PATH}") |
81 | 71 |
|
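`MODEL_PATH` is relative to the working directory, so the script only behaves as intended when launched from the right folder. Anchoring the path to the script's own location is more robust; a sketch with the standard library's pathlib, assuming the repository keeps `models/` next to the script's parent directory:

```python
from pathlib import Path

MODEL_PATH = Path(__file__).resolve().parent.parent / "models" / "k8s_failure_model.pkl"
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)  # create models/ if missing
joblib.dump(rf, MODEL_PATH)
model = joblib.load(MODEL_PATH)
```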
82 | | -# 🔥 Confusion Matrix Plot |
| 72 | +# Plot the confusion matrix |
83 | 73 | cm = confusion_matrix(y_test, y_pred_ensemble) |
84 | 74 | plt.figure(figsize=(6, 4)) |
85 | 75 | sns.heatmap(cm, annot=True, fmt='d', cmap="Blues", xticklabels=["No Failure", "Failure"], yticklabels=["No Failure", "Failure"]) |
|
88 | 78 | plt.ylabel("Actual") |
89 | 79 | plt.show() |
90 | 80 |
|
91 | | -# 🔥 Feature Importance Plot |
| 81 | +# Plot the top feature importances |
92 | 82 | feature_importances = pd.DataFrame({'Feature': X_train.columns, 'Importance': rf.feature_importances_}) |
93 | 83 | feature_importances = feature_importances.sort_values(by='Importance', ascending=False).head(15) |
94 | | - |
95 | 84 | plt.figure(figsize=(10, 6)) |
96 | 85 | sns.barplot(x='Importance', y='Feature', data=feature_importances, palette="viridis") |
97 | 86 | plt.title("Top 15 Important Features") |
|
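Impurity-based importances tend to split credit across correlated features, and every rolling `_avg` column here is highly correlated with its source column. Permutation importance on the held-out set is a useful cross-check; a sketch:

```python
from sklearn.inspection import permutation_importance

result = permutation_importance(rf, X_test, y_test, n_repeats=10, random_state=42)
perm = pd.Series(result.importances_mean, index=X_test.columns)
print(perm.sort_values(ascending=False).head(15))
```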