Skip to content

Commit 672559f

Browse files
committed
Optimized functions & changed from pickle to joblib
1 parent cc2b8d1 commit 672559f

35 files changed

+251
-76
lines changed

.github/workflows/tests.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,5 @@ jobs:
3838
- name: Test with pytest
3939
run: |
4040
pytest
41+
python -m unittest spam_detector_ai.tests.test_dataloader
42+
python -m unittest spam_detector_ai.tests.test_preprocessor

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,12 @@ The test results are shown below:
139139
| | Predicted: Ham | Predicted: Spam |
140140
|------------------|----------------------|---------------------|
141141
| **Actual: Ham** | 2080 (True Negative) | 25 (False Positive) |
142-
| **Actual: Spam** | 42 (False Negative) | 812 (True Positive) |
142+
| **Actual: Spam** | 41 (False Negative) | 813 (True Positive) |
143143

144144
- True Negative (TN): 2080 messages were correctly identified as ham (non-spam).
145145
- False Positive (FP): 25 ham messages were incorrectly identified as spam.
146-
- False Negative (FN): 42 spam messages were incorrectly identified as ham.
147-
- True Positive (TP): 812 messages were correctly identified as spam.
146+
- False Negative (FN): 41 spam messages were incorrectly identified as ham.
147+
- True Positive (TP): 813 messages were correctly identified as spam.
148148

149149
##### Performance Metrics:
150150

@@ -212,7 +212,7 @@ The project contains 3 pre-trained models that can be used directly if you want
212212
If you don't want to use the package, you can use the API that I have deployed
213213
[here](https://spam-detection-api.adamspierredavid.com/).
214214

215-
The API is built with Django and the following is an example of how I use it in a personal project:
215+
The API is built with Django, and the following is an example of how I use it in a personal project:
216216

217217
![Screenshot](./screenshots/spam-detection-api-example.png)
218218

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ nltk~=3.8.1
55
setuptools==69.0.3
66
pytest==7.4.4
77
requests~=2.31.0
8-
imblearn~=0.0
8+
imblearn~=0.0
9+
joblib~=1.3.2

spam_detector_ai/classifiers/base_classifier.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,28 @@
22

33
from abc import ABC, abstractmethod
44

5+
from joblib import dump, load
6+
57

68
class BaseClassifier(ABC):
9+
VECTORIZER_PARAMS = {
10+
'max_features': 1500,
11+
'min_df': 5,
12+
'max_df': 0.7
13+
}
14+
715
def __init__(self):
8-
pass
16+
self.classifier = None
17+
self.vectoriser = None
918

1019
@abstractmethod
1120
def train(self, X_train, y_train):
1221
pass
1322

14-
@abstractmethod
1523
def save_model(self, model_path, vectoriser_path):
16-
pass
24+
dump(self.classifier, model_path)
25+
dump(self.vectoriser, vectoriser_path)
26+
27+
def load_model(self, model_path, vectoriser_path):
28+
self.classifier = load(model_path)
29+
self.vectoriser = load(vectoriser_path)
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from spam_detector_ai.classifiers import ClassifierType, NaiveBayesClassifier, RandomForestSpamClassifier, SVMClassifier
2+
3+
CLASSIFIER_MAP = {
4+
ClassifierType.NAIVE_BAYES: NaiveBayesClassifier,
5+
ClassifierType.RANDOM_FOREST: RandomForestSpamClassifier,
6+
ClassifierType.SVM: SVMClassifier
7+
}
Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
# spam_detector_ai/classifiers/naive_bayes_classifier.py
22

3-
import pickle
4-
53
from sklearn.feature_extraction.text import CountVectorizer
64
from sklearn.naive_bayes import MultinomialNB
75

@@ -11,16 +9,9 @@
119
class NaiveBayesClassifier(BaseClassifier):
1210
def __init__(self):
1311
super().__init__()
14-
self.classifier = None
15-
self.vectoriser = CountVectorizer(max_features=1500, min_df=5, max_df=0.7)
12+
self.vectoriser = CountVectorizer(**BaseClassifier.VECTORIZER_PARAMS)
1613

1714
def train(self, X_train, y_train):
1815
X_train_vectorized = self.vectoriser.fit_transform(X_train).toarray()
1916
self.classifier = MultinomialNB()
2017
self.classifier.fit(X_train_vectorized, y_train)
21-
22-
def save_model(self, model_path, vectoriser_path):
23-
with open(model_path, 'wb') as file:
24-
pickle.dump(self.classifier, file)
25-
with open(vectoriser_path, 'wb') as file:
26-
pickle.dump(self.vectoriser, file)
Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,20 @@
11
# spam_detector_ai/classifiers/random_forest_classifier.py
22

3-
import pickle
4-
from sklearn.feature_extraction.text import TfidfVectorizer
5-
from sklearn.ensemble import RandomForestClassifier
63
from imblearn.over_sampling import SMOTE
4+
from sklearn.ensemble import RandomForestClassifier
5+
from sklearn.feature_extraction.text import TfidfVectorizer
6+
77
from .base_classifier import BaseClassifier
88

99

1010
class RandomForestSpamClassifier(BaseClassifier):
1111
def __init__(self):
1212
super().__init__()
13-
self.classifier = None
14-
self.vectoriser = TfidfVectorizer(max_features=1500, min_df=5, max_df=0.7)
13+
self.vectoriser = TfidfVectorizer(**BaseClassifier.VECTORIZER_PARAMS)
1514
self.smote = SMOTE(random_state=42)
1615

1716
def train(self, X_train, y_train):
1817
X_train_vectorized = self.vectoriser.fit_transform(X_train)
1918
X_train_res, y_train_res = self.smote.fit_resample(X_train_vectorized, y_train)
2019
self.classifier = RandomForestClassifier(n_estimators=100, random_state=0)
2120
self.classifier.fit(X_train_res, y_train_res)
22-
23-
def save_model(self, model_path, vectoriser_path):
24-
with open(model_path, 'wb') as file:
25-
pickle.dump(self.classifier, file)
26-
with open(vectoriser_path, 'wb') as file:
27-
pickle.dump(self.vectoriser, file)
Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,17 @@
11
# spam_detector_ai/classifiers/svm_classifier.py
22

3-
import pickle
43
from sklearn.feature_extraction.text import TfidfVectorizer
54
from sklearn.svm import SVC
5+
66
from .base_classifier import BaseClassifier
77

88

99
class SVMClassifier(BaseClassifier):
1010
def __init__(self):
1111
super().__init__()
12-
self.classifier = None
13-
self.vectoriser = TfidfVectorizer(max_features=1500, min_df=5, max_df=0.7)
12+
self.vectoriser = TfidfVectorizer(**BaseClassifier.VECTORIZER_PARAMS)
1413

1514
def train(self, X_train, y_train):
1615
X_train_vectorized = self.vectoriser.fit_transform(X_train)
17-
self.classifier = SVC(C=100, gamma=1, kernel='rbf')
16+
self.classifier = SVC(C=10, gamma=1, kernel='rbf')
1817
self.classifier.fit(X_train_vectorized, y_train)
19-
20-
def save_model(self, model_path, vectoriser_path):
21-
with open(model_path, 'wb') as file:
22-
pickle.dump(self.classifier, file)
23-
with open(vectoriser_path, 'wb') as file:
24-
pickle.dump(self.vectoriser, file)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# spam_detector_ai/loading_and_processing/__init__.py
22

33
from .data_loader import DataLoader
4-
from .preprocessor import Preprocessor
4+
from .preprocessor import Preprocessor

spam_detector_ai/loading_and_processing/data_loader.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,17 @@
55

66
class DataLoader:
77
def __init__(self, data_path):
8-
self.data = pd.read_csv(data_path)
8+
if not data_path.endswith('.csv'):
9+
raise ValueError("Only CSV files are supported")
10+
try:
11+
self.data = pd.read_csv(data_path)
12+
except FileNotFoundError:
13+
raise FileNotFoundError(f"The file at {data_path} was not found.")
14+
except Exception as e:
15+
raise Exception(f"An error occurred while loading the file: {e}")
916

1017
def get_data(self):
18+
"""
19+
Return the loaded data.
20+
"""
1121
return self.data

0 commit comments

Comments
 (0)