|
1 | 1 | # spam_detector_ai/classifiers/random_forest_classifier.py
|
2 | 2 |
|
3 |
| -import pickle |
4 |
| -from sklearn.feature_extraction.text import TfidfVectorizer |
5 |
| -from sklearn.ensemble import RandomForestClassifier |
6 | 3 | from imblearn.over_sampling import SMOTE
|
| 4 | +from sklearn.ensemble import RandomForestClassifier |
| 5 | +from sklearn.feature_extraction.text import TfidfVectorizer |
| 6 | + |
7 | 7 | from .base_classifier import BaseClassifier
|
8 | 8 |
|
9 | 9 |
|
10 | 10 | class RandomForestSpamClassifier(BaseClassifier):
|
11 | 11 | def __init__(self):
|
12 | 12 | super().__init__()
|
13 |
| - self.classifier = None |
14 |
| - self.vectoriser = TfidfVectorizer(max_features=1500, min_df=5, max_df=0.7) |
| 13 | + self.vectoriser = TfidfVectorizer(**BaseClassifier.VECTORIZER_PARAMS) |
15 | 14 | self.smote = SMOTE(random_state=42)
|
16 | 15 |
|
17 | 16 | def train(self, X_train, y_train):
|
18 | 17 | X_train_vectorized = self.vectoriser.fit_transform(X_train)
|
19 | 18 | X_train_res, y_train_res = self.smote.fit_resample(X_train_vectorized, y_train)
|
20 | 19 | self.classifier = RandomForestClassifier(n_estimators=100, random_state=0)
|
21 | 20 | self.classifier.fit(X_train_res, y_train_res)
|
22 |
| - |
23 |
| - def save_model(self, model_path, vectoriser_path): |
24 |
| - with open(model_path, 'wb') as file: |
25 |
| - pickle.dump(self.classifier, file) |
26 |
| - with open(vectoriser_path, 'wb') as file: |
27 |
| - pickle.dump(self.vectoriser, file) |
0 commit comments