from imblearn.over_sampling import BorderlineSMOTE
from xgboost import XGBClassifier

# Load dataset
CSV_PATH = "/home/pavithra/k8s-failure-prediction/data/merged_data.csv"
df = pd.read_csv(CSV_PATH)

# Preprocessing: normalize column names and index by timestamp
df.columns = df.columns.str.strip().str.replace(r'\s+', '_', regex=True).str.lower()
df["timestamp"] = pd.to_datetime(df["timestamp"])
df.set_index("timestamp", inplace=True)

# Feature engineering: 5-sample rolling means
# (iterate over a snapshot of the columns so the new *_avg features are not re-processed)
for col in df.columns.tolist():
    df[f"{col}_avg"] = df[col].rolling(window=5, min_periods=1).mean()

# Target: flag a failure when the restart count jumps by more than 1 between samples
df["target"] = (df["container_restart_count"].diff().fillna(0) > 1).astype(int)
df.drop(columns=["container_restart_count"], inplace=True)
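# Note: the rolling feature container_restart_count_avg created above survives
# this drop; drop it here as well if target leakage is a concern for your data.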

# Prepare feature matrix and labels
X = df.drop(columns=["target"])
y = df["target"]

# Handle class imbalance with BorderlineSMOTE when the minority class is large enough
if y.value_counts().min() > 5:
    smote = BorderlineSMOTE(sampling_strategy='auto', random_state=42)
    X_resampled, y_resampled = smote.fit_resample(X, y)
else:
    X_resampled, y_resampled = X, y
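# Note on the guard: BorderlineSMOTE interpolates between each minority sample
# and its k nearest minority neighbours (k_neighbors defaults to 5 in
# imbalanced-learn), so fit_resample needs strictly more than 5 minority
# samples; below that threshold the raw, unresampled data is used.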

# Train/test split
X_train, X_test, y_train, y_test = train_test_split(X_resampled, y_resampled, test_size=0.2, random_state=42)
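# Suggested tweak (not in the original): passing stratify=y_resampled here
# keeps the class ratio identical in train and test, which matters most when
# SMOTE was skipped and the target is still imbalanced.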

# Random forest, regularized to reduce overfitting
rf = RandomForestClassifier(
    n_estimators=300,       # more trees
    max_depth=10,           # cap tree depth
    min_samples_split=20,   # more samples needed per split
    min_samples_leaf=10,    # prevent small branches
    bootstrap=True,
    random_state=42
)

# Ensemble (random forest + XGBoost)
xgb = XGBClassifier(n_estimators=200, learning_rate=0.05, max_depth=7,
                    subsample=0.8, colsample_bytree=0.8, random_state=42)
rf.fit(X_train, y_train)
xgb.fit(X_train, y_train)

# Per-model predictions
y_pred_rf = rf.predict(X_test)
y_pred_xgb = xgb.predict(X_test)

# Combine by unanimous vote: with 0/1 labels, integer division by 2 yields 1
# only when both models predict failure (hard AND-voting, not soft voting)
y_pred_ensemble = (y_pred_rf + y_pred_xgb) // 2
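# True soft voting, for comparison (a sketch, not part of the original code):
# average the two models' failure probabilities and threshold at 0.5.
#   proba = (rf.predict_proba(X_test)[:, 1] + xgb.predict_proba(X_test)[:, 1]) / 2
#   y_pred_ensemble = (proba >= 0.5).astype(int)
# sklearn's VotingClassifier(estimators=[("rf", rf), ("xgb", xgb)], voting="soft")
# wraps the same idea behind a single fit/predict interface.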

# Evaluation: the train score is RF-only, while the test score uses the ensemble
train_acc = rf.score(X_train, y_train) * 100
test_acc = accuracy_score(y_test, y_pred_ensemble) * 100
print(f"\n🎯 Train Accuracy: {train_acc:.2f} %")
print(f"🎯 Test Accuracy: {test_acc:.2f} %")
print("\n🔹 Classification Report:\n", classification_report(y_test, y_pred_ensemble))

# Save the trained random forest, then reload it from the same path as a sanity check
MODEL_PATH = "../models/k8s_failure_model.pkl"
joblib.dump(rf, MODEL_PATH)
model = joblib.load(MODEL_PATH)
print("The features in the model are\n")
print(model.feature_names_in_)
print(f"\n✅ Model saved at {MODEL_PATH}")
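# Only rf is persisted; to serve the AND-vote ensemble as-is, save both
# estimators (e.g. joblib.dump({"rf": rf, "xgb": xgb}, MODEL_PATH)) and
# recombine their predictions at inference time (a suggestion, not original code).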

# Confusion matrix plot
cm = confusion_matrix(y_test, y_pred_ensemble)
plt.figure(figsize=(6, 4))
sns.heatmap(cm, annot=True, fmt='d', cmap="Blues", xticklabels=["No Failure", "Failure"], yticklabels=["No Failure", "Failure"])
plt.title("Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.show()

# Feature importance plot: top 15 features by random forest importance
feature_importances = pd.DataFrame({'Feature': X_train.columns, 'Importance': rf.feature_importances_})
feature_importances = feature_importances.sort_values(by='Importance', ascending=False).head(15)

plt.figure(figsize=(10, 6))
sns.barplot(x='Importance', y='Feature', data=feature_importances, palette="viridis")
plt.title("Top 15 Important Features")
plt.show()