diff --git a/code/__pycache__/svm_source_classifier.cpython-311.pyc b/code/__pycache__/svm_source_classifier.cpython-311.pyc new file mode 100644 index 0000000..3fd6afc Binary files /dev/null and b/code/__pycache__/svm_source_classifier.cpython-311.pyc differ diff --git a/code/svm_source_classifier.py b/code/svm_source_classifier.py new file mode 100644 index 0000000..2550989 --- /dev/null +++ b/code/svm_source_classifier.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- +""" +SVM_source_classifier.py +This code trains a SVM model to classify the source of a text document +based on writing style and topic using the fetch_20newsgroups dataset. +Removes headers, footers, and quotes to prevent data leakage. +Author: Jane Heng +Date: Oct.10.2025 +""" +from sklearn.datasets import fetch_20newsgroups +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.svm import SVC +from sklearn.pipeline import make_pipeline +from sklearn import metrics + +def load_data(categories): + """ + Loads training and testing data from the 20 Newsgroups dataset for the specified categories. + Removes headers, footers, and quotes to reduce noise and prevent data leakage. + + Args: + categories (list of str): List of category names to include. + + Returns: + tuple: A tuple containing: + - train_data (list of str): Training text samples. + - train_target (list of int): Labels for training samples. + - test_data (list of str): Testing text samples. + - test_target (list of int): Labels for testing samples. + - target_names (list of str): Names of the target categories. + """ + train_data = fetch_20newsgroups( + subset='train', + categories=categories, + remove=('headers', 'footers', 'quotes') + ) + test_data = fetch_20newsgroups( + subset='test', + categories=categories, + remove=('headers', 'footers', 'quotes') + ) + return train_data.data, train_data.target, test_data.data, test_data.target, train_data.target_names + +def train_model(X_train, y_train): + """ + Trains a Support Vector Machine (SVM) classifier using TF-IDF features. + + Args: + X_train (list of str): Training text data. + y_train (list of int): Corresponding labels for training data. + + Returns: + sklearn.pipeline.Pipeline: A trained pipeline combining TF-IDF vectorization and SVM classification. + """ + model = make_pipeline(TfidfVectorizer(), SVC(kernel='linear')) + model.fit(X_train, y_train) + return model + +def evaluate_model(model, X_test, y_test, target_names): + """ + Evaluates the trained model on the test data and prints a classification report. + + Args: + model (Pipeline): Trained model pipeline. + X_test (list of str): Test text data. + y_test (list of int): True labels for test data. + target_names (list of str): Names of the target categories. + + Returns: + None + """ + predicted = model.predict(X_test) + print(metrics.classification_report(y_test, predicted, target_names=target_names)) + +def run_source_classification(categories): + """ + Executes the full classification pipeline: + - Loads data for the specified categories + - Trains an SVM model + - Evaluates model performance + + Args: + categories (list of str): List of category names to classify. + + Returns: + sklearn.pipeline.Pipeline: The trained model pipeline. + """ + X_train, y_train, X_test, y_test, target_names = load_data(categories) + model = train_model(X_train, y_train) + evaluate_model(model, X_test, y_test, target_names) + return model + +if __name__ == "__main__": + categories = ['talk.politics.misc', 'rec.sport.hockey', 'sci.space'] + run_source_classification(categories) \ No newline at end of file diff --git a/code/test_source_classifier.py b/code/test_source_classifier.py new file mode 100644 index 0000000..8ed79a3 --- /dev/null +++ b/code/test_source_classifier.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- + +""" +test_source_classifier.py +Unit test for svm_source_classifier.py +Author: Jane Heng +Date: Oct.10.2025 +""" +import unittest +from svm_source_classifier import load_data, train_model, evaluate_model, run_source_classification +from unittest.mock import patch +import io + +class TestSourceClassifier(unittest.TestCase): + """ + Unit test suite for verifying the functionality of the SVM source classifier pipeline. + """ + + def setUp(self): + """ + Initializes sample input data and category labels. + """ + self.categories = ['sci.space', 'rec.sport.hockey', 'talk.politics.misc'] + self.X_sample = [ + "NASA launched a new satellite.", + "The hockey team won the championship.", + "Political debates are heating up." + ] + self.y_sample = [0, 1, 2] + + def test_load_data(self): + """ + Tests whether load_data correctly loads training and testing data. + Check if data lengths match and category names. + """ + X_train, y_train, X_test, y_test, target_names = load_data(self.categories) + self.assertEqual(len(X_train), len(y_train)) + self.assertEqual(len(X_test), len(y_test)) + self.assertEqual(len(target_names), len(self.categories)) + self.assertIn('sci.space', target_names) + print("test_load_data: PASS") + + def test_train_model(self): + """ + Tests if the model is properly trained on sample input. + """ + model = train_model(self.X_sample, self.y_sample) + self.assertTrue(hasattr(model, "predict")) + print("test_train_model: PASS") + + def test_evaluate_model(self): + """ + Tests whether evaluate_model function. + """ + model = train_model(self.X_sample, self.y_sample) + with patch('sys.stdout', new=io.StringIO()) as fake_out: + evaluate_model(model, self.X_sample, self.y_sample, self.categories) + output = fake_out.getvalue() + self.assertIn('sci.space', output) + print("test_evaluate_model: PASS") + + def test_run_source_classification(self): + """ + Tests whether run_source_classification function. + """ + model = run_source_classification(self.categories) + self.assertTrue(hasattr(model, "predict")) + print("test_run_source_classification: PASS") + +if __name__ == "__main__": + unittest.main() \ No newline at end of file