Skip to content

Commit d3d2b55

Browse files
committed
Fixed test
1 parent 672559f commit d3d2b55

File tree

3 files changed

+15
-7
lines changed

3 files changed

+15
-7
lines changed
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from spam_detector_ai.classifiers import ClassifierType, NaiveBayesClassifier, RandomForestSpamClassifier, SVMClassifier
22

33
CLASSIFIER_MAP = {
4-
ClassifierType.NAIVE_BAYES: NaiveBayesClassifier,
5-
ClassifierType.RANDOM_FOREST: RandomForestSpamClassifier,
6-
ClassifierType.SVM: SVMClassifier
4+
ClassifierType.NAIVE_BAYES.value: NaiveBayesClassifier(),
5+
ClassifierType.RANDOM_FOREST.value: RandomForestSpamClassifier(),
6+
ClassifierType.SVM.value: SVMClassifier()
77
}

spam_detector_ai/prediction/predict.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import os
99

10+
from spam_detector_ai.classifiers import NaiveBayesClassifier, RandomForestSpamClassifier, SVMClassifier
1011
from spam_detector_ai.classifiers.classifier_map import CLASSIFIER_MAP
1112
from spam_detector_ai.classifiers.classifier_types import ClassifierType
1213
from spam_detector_ai.loading_and_processing import Preprocessor
@@ -49,11 +50,18 @@ class SpamDetector:
4950
"""This class is used to detect whether a message is spam or not spam."""
5051

5152
def __init__(self, model_type=ClassifierType.NAIVE_BAYES):
52-
classifier_class = CLASSIFIER_MAP.get(model_type)
53-
if not classifier_class:
53+
classifier_map = {
54+
ClassifierType.NAIVE_BAYES.value: NaiveBayesClassifier(),
55+
ClassifierType.RANDOM_FOREST.value: RandomForestSpamClassifier(),
56+
ClassifierType.SVM.value: SVMClassifier()
57+
}
58+
classifier = classifier_map.get(model_type.value)
59+
print(f"classifier_class: {classifier}")
60+
print(f"model_type: {model_type}")
61+
if not classifier:
5462
raise ValueError(f"Invalid model type: {model_type}")
5563

56-
self.model = classifier_class()
64+
self.model = classifier
5765
model_path, vectoriser_path = get_model_path(model_type)
5866
self.model.load_model(model_path, vectoriser_path)
5967
self.processor = Preprocessor()

spam_detector_ai/test_and_tuning/py_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def test_model():
1717
current_dir = os.path.dirname(os.path.abspath(__file__))
1818
base_dir = os.path.dirname(current_dir)
1919
data_path = os.path.join(base_dir, 'data/spam.csv')
20-
initial_trainer = ModelTrainer(data_path=data_path, classifier_type=None, logger=logger)
20+
initial_trainer = ModelTrainer(data_path=data_path, logger=logger)
2121
processed_data = initial_trainer.preprocess_data_()
2222
_, X_test, _, y_test = train_test_split(processed_data['processed_text'], processed_data['label'],
2323
test_size=0.2, random_state=0)

0 commit comments

Comments
 (0)