Skip to content

Commit 463bd40

Browse files
committed
Evaluate the sensitivity and specificity of the text classification model
1 parent 25526e9 commit 463bd40

File tree

1 file changed

+45
-26
lines changed

1 file changed

+45
-26
lines changed

training/text_classification.py

Lines changed: 45 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
from sklearn.calibration import CalibratedClassifierCV
55
from sklearn.ensemble import RandomForestClassifier
66
from sklearn.feature_extraction.text import TfidfVectorizer
7-
from sklearn.metrics import accuracy_score, classification_report
7+
from sklearn.metrics import accuracy_score, recall_score, classification_report
88
from sklearn.model_selection import train_test_split
99
from sklearn.naive_bayes import ComplementNB, MultinomialNB
1010
from sklearn.pipeline import Pipeline
@@ -19,7 +19,7 @@ def main():
1919
- min_df=5
2020
"""
2121
df = pd.read_csv("data.csv")
22-
print(df["rna_related"].value_counts())
22+
print(df["rna_related"].value_counts(), "\n")
2323
# rna_related
2424
# 1 3363
2525
# 0 3331
@@ -28,32 +28,54 @@ def main():
2828
y = df["rna_related"]
2929
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
3030

31-
pipeMNB = Pipeline(steps=[("tfidf", TfidfVectorizer()), ("clf", MultinomialNB())])
32-
pipeMNB.fit(X_train, y_train)
33-
predictMNB = pipeMNB.predict(X_test)
31+
classifiers = {
32+
"MultinomialNB": Pipeline(steps=[("tfidf", TfidfVectorizer()), ("clf", MultinomialNB())]),
33+
"ComplementNB": Pipeline(steps=[("tfidf", TfidfVectorizer()), ("clf", ComplementNB())]),
34+
"LinearSVC": Pipeline(steps=[("tfidf", TfidfVectorizer()), ("clf", CalibratedClassifierCV(LinearSVC()))]),
35+
"RandomForest": Pipeline(steps=[("tfidf", TfidfVectorizer()), ("clf", RandomForestClassifier())])
36+
}
3437

35-
pipeCNB = Pipeline(steps=[("tfidf", TfidfVectorizer()), ("clf", ComplementNB())])
36-
pipeCNB.fit(X_train, y_train)
37-
predictCNB = pipeCNB.predict(X_test)
38+
best_accuracy = 0
39+
best_classifier_name = None
40+
best_pipeline = None
3841

39-
pipeSVC = Pipeline(steps=[("tfidf", TfidfVectorizer()), ("clf", CalibratedClassifierCV(LinearSVC()))])
40-
pipeSVC.fit(X_train, y_train)
41-
predictSVC = pipeSVC.predict(X_test)
42+
for name, pipeline in classifiers.items():
43+
pipeline.fit(X_train, y_train)
44+
y_pred = pipeline.predict(X_test)
4245

43-
pipeRF = Pipeline(steps=[("tfidf", TfidfVectorizer()), ("clf", RandomForestClassifier())])
44-
pipeRF.fit(X_train, y_train)
45-
predictRF = pipeRF.predict(X_test)
46+
accuracy = accuracy_score(y_test, y_pred)
47+
sensitivity = recall_score(y_test, y_pred, pos_label=1)
48+
specificity = recall_score(y_test, y_pred, pos_label=0)
49+
print(
50+
f"{name} - Accuracy: {accuracy:.2f}, "
51+
f"Sensitivity: {sensitivity:.2f}, "
52+
f"Specificity: {specificity:.2f}"
53+
)
54+
# MultinomialNB - Accuracy: 0.93, Sensitivity: 0.99, Specificity: 0.88
55+
# ComplementNB - Accuracy: 0.93, Sensitivity: 0.99, Specificity: 0.88
56+
# LinearSVC - Accuracy: 0.98, Sensitivity: 0.98, Specificity: 0.98
57+
# RandomForest - Accuracy: 0.96, Sensitivity: 0.98, Specificity: 0.93
4658

47-
print(f"MNB: {accuracy_score(y_test, predictMNB):.2f}")
48-
print(f"CNB: {accuracy_score(y_test, predictCNB):.2f}")
49-
print(f"SVC: {accuracy_score(y_test, predictSVC):.2f}")
50-
print(f"RF: {accuracy_score(y_test, predictRF):.2f}")
51-
# MNB: 0.93
52-
# CNB: 0.93
53-
# SVC: 0.98
54-
# RF: 0.96
59+
if accuracy > best_accuracy:
60+
best_accuracy = accuracy
61+
best_classifier_name = name
62+
best_pipeline = pipeline
5563

56-
print(classification_report(y_test, predictSVC))
64+
if best_pipeline is not None:
65+
# save the best classifier pipeline
66+
joblib.dump(best_pipeline, f"{best_classifier_name}_pipeline.pkl")
67+
print(
68+
f"\nSaved the best classifier ({best_classifier_name}) "
69+
f"with accuracy {best_accuracy:.2f} "
70+
f"to '{best_classifier_name}_pipeline.pkl'"
71+
)
72+
# Saved the best classifier (LinearSVC) with accuracy 0.98 to 'LinearSVC_pipeline.pkl'
73+
74+
# display classification report for the best classifier pipeline
75+
print(f"\nClassification Report for {best_classifier_name}:")
76+
print(classification_report(y_test, classifiers[best_classifier_name].predict(X_test)))
77+
78+
# Classification Report for LinearSVC:
5779
# precision recall f1-score support
5880
#
5981
# 0 0.98 0.98 0.98 665
@@ -63,8 +85,5 @@ def main():
6385
# macro avg 0.98 0.98 0.98 1339
6486
# weighted avg 0.98 0.98 0.98 1339
6587

66-
joblib.dump(pipeSVC, "svc_pipeline.pkl")
67-
68-
6988
if __name__ == "__main__":
7089
main()

0 commit comments

Comments
 (0)