Skip to content

Commit 906eda6

Browse files
committed
small changes
1 parent 12505f6 commit 906eda6

File tree

8 files changed

+55
-9
lines changed

8 files changed

+55
-9
lines changed

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ The more data you have, the better the models will perform.
5555
To train the models, run the following command:
5656

5757
```sh
58-
python trainer.py
58+
python3 spam_detector_ai/trainer.py
5959
```
6060

6161
This will train all the models and save them to the `models` directory. For now, there is 3 models:
@@ -66,9 +66,9 @@ This will train all the models and save them to the `models` directory. For now,
6666

6767
### Tests
6868

69-
The tests results are shown below:
69+
The test results are shown below:
7070

71-
#### <u>Model: NAIVE_BAYES</u>
71+
#### _Model: NAIVE_BAYES_
7272

7373
##### Confusion Matrix:
7474

@@ -96,7 +96,7 @@ The tests results are shown below:
9696

9797
<br>
9898

99-
#### <u>Model: RANDOM_FOREST</u>
99+
#### _Model: RANDOM_FOREST_
100100

101101
##### Confusion Matrix:
102102

@@ -124,7 +124,7 @@ The tests results are shown below:
124124

125125
<br>
126126

127-
#### <u>Model: SVM</u>
127+
#### _Model: SVM_
128128

129129
##### Confusion Matrix:
130130

requirements.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
scikit-learn~=1.2.2
22
imblearn~=0.0
33
pandas~=2.0.2
4-
nltk~=3.8.1
4+
nltk~=3.8.1
5+
setuptools~=67.8.0
6+
pytest~=7.3.2

spam_detector_ai/loading_and_processing/data_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# data_loader.py
1+
# spam_detector_ai/loading_and_processing/data_loader.py
22

33
import pandas as pd
44

spam_detector_ai/loading_and_processing/preprocessor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# preprocessor.py
1+
# spam_detector_ai/loading_and_processing/preprocessor.py
22

33
import re
44

spam_detector_ai/prediction/predict.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
# spam_detector_ai/prediction/predict.py
2+
"""
3+
Author: Adams P. David
4+
Contact: https://adamspierredavid.com/contact/
5+
Date Written: 2023-06-12
6+
"""
27

38
import os
49
import pickle
@@ -41,6 +46,8 @@ def get_model_path(model_type):
4146

4247

4348
class SpamDetector:
49+
"""This class is used to detect whether a message is spam or not spam."""
50+
4451
def __init__(self, model_type=ClassifierType.NAIVE_BAYES):
4552
# Determine paths based on model's type
4653
model_path, vectoriser_path = get_model_path(model_type)
@@ -79,6 +86,9 @@ def test_is_spam(self, message_):
7986

8087

8188
class VotingSpamDetector:
89+
"""This class is used to detect whether a message is spam
90+
or not spam using majority voting of multiple spam detectors models."""
91+
8292
def __init__(self):
8393
self.detectors = [
8494
SpamDetector(model_type=ClassifierType.NAIVE_BAYES),
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from .test import TestModel
1+
from .test import TestModel
2+
from .py_test import TestClassifiers

spam_detector_ai/test_and_tuning/fine_tuning_svm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# spam_detector_ai/test_and_tuning/fine_tuning_svm.py
2+
13
from sklearn.feature_extraction.text import TfidfVectorizer
24
from sklearn.model_selection import GridSearchCV
35
from sklearn.svm import SVC
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import os
2+
import pytest
3+
from sklearn.metrics import accuracy_score
4+
from sklearn.model_selection import train_test_split
5+
from spam_detector_ai.classifiers.classifier_types import ClassifierType
6+
from spam_detector_ai.logger_config import init_logging
7+
from spam_detector_ai.prediction import SpamDetector
8+
from spam_detector_ai.training import ModelTrainer
9+
10+
11+
@pytest.fixture(scope="module")
12+
def test_model():
13+
classifier_types = [ClassifierType.NAIVE_BAYES, ClassifierType.RANDOM_FOREST, ClassifierType.SVM]
14+
logger = init_logging()
15+
current_dir = os.path.dirname(os.path.abspath(__file__))
16+
base_dir = os.path.dirname(current_dir)
17+
data_path = os.path.join(base_dir, 'data/spam.csv')
18+
initial_trainer = ModelTrainer(data_path=data_path, classifier_type=None, logger=logger)
19+
processed_data = initial_trainer._preprocess_data()
20+
_, X_test, _, y_test = train_test_split(processed_data['processed_text'], processed_data['label'],
21+
test_size=0.2, random_state=0)
22+
return classifier_types, X_test, y_test
23+
24+
25+
class TestClassifiers:
26+
def test_classifier_accuracy(self, test_model):
27+
classifier_types, X_test, y_test = test_model
28+
for ct in classifier_types:
29+
detector = SpamDetector(model_type=ct)
30+
y_pred = [detector.test_is_spam(message) for message in X_test]
31+
assert accuracy_score(y_test, y_pred) > 0.85

0 commit comments

Comments
 (0)