Skip to content

Commit b6e1dce

Browse files
committed
temporary fix for path problem
1 parent b4465e8 commit b6e1dce

File tree

2 files changed

+17
-11
lines changed

2 files changed

+17
-11
lines changed

spam_detector_ai/trainer.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
# spam_detector_ai/trainer.py
2+
import os
3+
import sys
4+
from pathlib import Path
25

36
from sklearn.model_selection import train_test_split
47

8+
project_root = Path(__file__).parent.parent
9+
sys.path.append(str(project_root))
10+
511
from classifiers.classifier_types import ClassifierType
612
from logger_config import init_logging
713
from training.train_models import ModelTrainer
@@ -18,7 +24,8 @@ def train_model(classifier_type, model_filename, vectoriser_filename, X_train, y
1824

1925
if __name__ == '__main__':
2026
# Load and preprocess data once
21-
initial_trainer = ModelTrainer(data_path='data/spam.csv', logger=logger)
27+
data_path = os.path.join(project_root, 'spam_detector_ai', 'data', 'spam.csv')
28+
initial_trainer = ModelTrainer(data_path=data_path, logger=logger)
2229
processed_data = initial_trainer.preprocess_data_()
2330

2431
# Split the data once

spam_detector_ai/training/train_models.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# spam_detector_ai/training/train_models.py
22

33
import os
4+
from pathlib import Path
45

56
from sklearn.model_selection import train_test_split
67

@@ -77,19 +78,17 @@ def get_directory_path(self):
7778
raise ValueError(f"Invalid classifier type: {self.classifier_type}")
7879

7980
def save_model(self, model_filename, vectoriser_filename):
80-
# Determine the directory of this file
81-
current_dir = os.path.dirname(os.path.abspath(__file__))
82-
# Assuming the spam_detector_ai directory is one level up from the current directory
83-
base_dir = os.path.dirname(current_dir)
81+
# Use the project root to construct the paths
82+
project_root = Path(__file__).parent.parent
83+
models_dir = project_root
8484
directory_path = self.get_directory_path()
8585

86-
# Ensure the directory exists
87-
if not os.path.exists(directory_path):
88-
os.makedirs(directory_path)
86+
model_filepath = models_dir / directory_path / model_filename
87+
vectoriser_filepath = models_dir / directory_path / vectoriser_filename
8988

90-
model_filepath = os.path.join(base_dir, directory_path, model_filename)
91-
vectoriser_filepath = os.path.join(base_dir, directory_path, vectoriser_filename)
89+
# Ensure the directory exists
90+
model_filepath.parent.mkdir(parents=True, exist_ok=True)
9291

9392
self.logger.info(f'Saving model to {model_filepath}')
94-
self.classifier.save_model(model_filepath, vectoriser_filepath)
93+
self.classifier.save_model(str(model_filepath), str(vectoriser_filepath))
9594
self.logger.info('Model saved.\n')

0 commit comments

Comments
 (0)