Skip to content

Commit a90df27

Browse files
authored
Create ensemble_models.py
1 parent 00dcaab commit a90df27

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import pandas as pd
2+
import numpy as np
3+
import sklearn
4+
import os
5+
import pickle
6+
import copy
7+
from sklearn.metrics import roc_curve, auc
8+
from sklearn.metrics import accuracy_score, classification_report # 用于评估准确率和分类报告
9+
import joblib
10+
11+
def ensemable_classcification(X, y = None, rt_type='Prob'):
12+
# 加载 LabelEncoder
13+
label_encoder_path = "label_encoder.pkl"
14+
label_encoder = joblib.load(label_encoder_path)
15+
print(f"LabelEncoder已加载,类别标签为: {label_encoder.classes_}")
16+
17+
# 加载模型
18+
svm_model_path = "svm_model.pkl"
19+
svm_model = joblib.load(svm_model_path)
20+
print(f"SVM模型已加载")
21+
22+
log_reg_model_path = "log_reg_model.pkl"
23+
log_reg_model = joblib.load(log_reg_model_path)
24+
print(f"Logistic回归模型已加载")
25+
26+
xgb_model_path = "xgb_model.pkl"
27+
xgb_model = joblib.load(xgb_model_path)
28+
print(f"XGBoost模型已加载")
29+
30+
if y:
31+
print(classification_report(y, svm_model.predict(X), target_names=label_encoder.classes_))
32+
print(classification_report(y, log_reg_model.predict(X), target_names=label_encoder.classes_))
33+
print(classification_report(y, xgb_model.predict(X), target_names=label_encoder.classes_))
34+
35+
if rt_type == 'Class':
36+
svm_result = svm_model.predict(X)
37+
log_reg_result = log_reg_model.predict(X)
38+
xgb_result = xgb_model.predict(X)
39+
return svm_result, log_reg_result, xgb_result
40+
elif rt_type == 'Prob':
41+
svm_result = svm_model.predict_proba(X)
42+
log_reg_result = log_reg_model.predict_proba(X)
43+
xgb_result = xgb_model.predict_proba(X)
44+
return svm_result, log_reg_result, xgb_result
45+
elif rt_type == 'mean_Prob':
46+
svm_result = svm_model.predict_proba(X)
47+
log_reg_result = log_reg_model.predict_proba(X)
48+
xgb_result = xgb_model.predict_proba(X)
49+
mean_prob = {label_encoder.classes_[i]: np.mean([svm_result[:, i], log_reg_result[:, i], xgb_result[:, i]], axis=0)
50+
for i in range(len(label_encoder.classes_))}
51+
return mean_prob
52+
else:
53+
raise TypeError('Please provide a legal return param.')

0 commit comments

Comments
 (0)