A state-of-the-art SMS spam detection system using DistilBERT with advanced preprocessing and hyperparameter optimization.
Documentation • Quick Start • Performance • Contributing
This implementation targets practical mobile deployment without sacrificing accuracy:
- Distillation & small base model: DistilBERT retains ~97% of BERT performance with ~40% fewer parameters (~66M vs 110M), simplifying on-device memory use.
- Limited fine-tuning footprint: only the last two transformer layers + classifier head are trainable, reducing retraining cost and enabling lightweight updates.
- Export-ready paths: TorchScript and ONNX export scripts are included to convert models for Android/iOS (via PyTorch Mobile / ONNX Runtime Mobile).
- Quantization & pruning support: Model and training design allow 8-bit quantization and structured pruning with minimal accuracy loss.
- Low-latency considerations: batching, input length cap (128 tokens), attention caching and optimized tokenization reduce per-message inference time.
- Explainability: attention maps provide diagnostic signals for mobile A/B testing and on-device explainability features.
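The 8-bit quantization path mentioned above can be sketched with PyTorch's dynamic quantization. A stand-in classifier head is used here so the snippet runs without the fine-tuned weights; the same `quantize_dynamic` call applies to the full model:

```python
import torch
import torch.nn as nn

# Stand-in for the fine-tuned classifier; substitute the real model in practice.
model = nn.Sequential(nn.Linear(768, 768), nn.ReLU(), nn.Linear(768, 2)).eval()

# Convert all Linear layers to int8-weight dynamic quantized versions.
quantized = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

x = torch.randn(1, 768)
with torch.no_grad():
    y = quantized(x)          # same interface as the float model
print(y.shape)
```

Dynamic quantization runs on CPU and needs no calibration data, which makes it the lowest-effort option for the mobile targets listed above.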
```mermaid
graph TD
    A[Input Layer] --> B[DistilBERT Base]
    B --> C[Transformer Layer 1]
    C --> D[Transformer Layer 2]
    D --> E[Transformer Layer 3]
    E --> F[Transformer Layer 4]
    F --> G[Transformer Layer 5]
    G --> H[Transformer Layer 6]
    H --> I[Classification Head]
    I --> J[Output]
    subgraph "Architecture Details"
        K["Hidden Size: 768"]
        L["Attention Heads: 12"]
        M["Parameters: ~66M"]
    end
```
```mermaid
pie title Model Size Comparison
    "DistilBERT (Our Model)" : 66
    "BERT Base" : 110
```
```mermaid
graph TD
    A[SMS Input] --> B[Text Preprocessing]
    B --> C[DistilBERT Model]
    C --> D[Classification]
    D --> E[Spam/Ham Output]
    subgraph Preprocessing
        B --> B1[Text Cleaning]
        B1 --> B2[Abbreviation Handling]
        B2 --> B3[URL/Email Detection]
        B3 --> B4[Emoji Processing]
    end
    subgraph "Model Processing"
        C --> C1[Tokenization]
        C1 --> C2[Feature Extraction]
        C2 --> C3[Attention Mechanism]
    end
```
```mermaid
graph LR
    A[Data Collection] --> B[Preprocessing]
    B --> C[Model Training]
    C --> D[Evaluation]
    D --> E[Deployment & Export]
    style A fill:#FF6A00,stroke:#000000,color:#ffffff
    style B fill:#0052CC,stroke:#000000,color:#ffffff
    style C fill:#4C4C4C,stroke:#000000,color:#ffffff
    style D fill:#00A14B,stroke:#000000,color:#ffffff
    style E fill:#6C757D,stroke:#000000,color:#ffffff
```
```mermaid
flowchart LR
    Raw[Raw SMS] --> Clean[Lowercase & Trim]
    Clean --> Mask[URL / Email / Phone Masking]
    Mask --> Abbrev[Abbreviation Expansion]
    Abbrev --> Repeat[Repeated-char Normalization]
    Repeat --> Emoji["Emoji Normalization (to emoji token)"]
    Emoji --> Token[Tokenization & Lemmatization]
    Token --> Stop[Selective Stopword Filtering]
    Stop --> Output[Processed Text]
    style Raw fill:#FF6A00,stroke:#000000,color:#ffffff
    style Clean fill:#00B4D8,stroke:#000000,color:#000000
    style Mask fill:#333333,stroke:#000000,color:#ffffff
    style Abbrev fill:#00A14B,stroke:#000000,color:#ffffff
    style Repeat fill:#6C757D,stroke:#000000,color:#ffffff
    style Emoji fill:#FF6A00,stroke:#000000,color:#ffffff
    style Token fill:#00B4D8,stroke:#000000,color:#000000
    style Stop fill:#333333,stroke:#000000,color:#ffffff
    style Output fill:#ffffff,stroke:#000000,color:#000000
```
- Text Normalization: Convert to lowercase, remove extra spaces
- Abbreviation Handling: u → you, ur → your
- Pattern Detection: URLs, emails, phone numbers
- Emoji Processing: Standardization and handling
- Smart Stop Word Filtering: Context-aware removal
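The steps above can be sketched in a few lines of standard-library Python. This is a hypothetical, minimal version of the notebook's preprocessor; the real pipeline additionally lemmatizes with spaCy and applies context-aware stop-word filtering:

```python
import re

# Illustrative subset; the notebook's abbreviation table is larger.
ABBREVIATIONS = {"u": "you", "ur": "your", "txt": "text", "wkly": "weekly"}

def preprocess(text):
    text = text.lower().strip()
    text = re.sub(r"https?://\S+|www\.\S+", "<url>", text)          # mask URLs
    text = re.sub(r"\b[\w.+-]+@[\w-]+\.[\w.]+\b", "<email>", text)  # mask emails
    text = re.sub(r"\b\d{10,11}\b", "<phone>", text)                # mask phone numbers
    text = re.sub(r"(.)\1{2,}", r"\1\1", text)                      # "sooooon" -> "soon"
    tokens = [ABBREVIATIONS.get(t, t) for t in text.split()]
    return " ".join(tokens)

print(preprocess("URGENT!! Txt u at www.win-cash.example sooooon"))
# -> urgent!! text you at <url> soon
```

Masking (rather than deleting) URLs and phone numbers preserves a strong spam signal while keeping the vocabulary small.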
DistilBERT Architecture
- Fine-tuned transformer model for SMS classification
- Selective layer unfreezing (last 2 layers trainable)
- Attention visualization capabilities
- Mobile-optimized architecture (66M vs 110M parameters)
- 97% of BERT performance with 40% smaller size
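The selective unfreezing scheme can be sketched without downloading pretrained weights; below, stand-in `Linear` blocks play the role of DistilBERT's six transformer layers. On the real model the layer list typically lives at `model.distilbert.transformer.layer` (verify this path against your model class):

```python
import torch.nn as nn

# Stand-ins for DistilBERT's six transformer blocks plus the classifier head.
layers = nn.ModuleList([nn.Linear(768, 768) for _ in range(6)])
classifier = nn.Linear(768, 2)

# Freeze everything, then unfreeze only the last two blocks.
for p in layers.parameters():
    p.requires_grad = False
for block in layers[-2:]:
    for p in block.parameters():
        p.requires_grad = True

trainable = sum(p.numel() for p in layers.parameters() if p.requires_grad) \
          + sum(p.numel() for p in classifier.parameters())
print(f"Trainable parameters in this sketch: {trainable:,}")
```

Only the unfrozen parameters should be handed to the optimizer, e.g. `AdamW(p for p in model.parameters() if p.requires_grad)`.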
Comprehensive Analysis
- Stratified data splitting (70% train, 15% val, 15% test)
- Class imbalance handling with weighted loss
- Performance visualization and metrics tracking
- Confusion matrix analysis and attention heatmaps
Hyperparameter Optimization
- Optuna integration with Bayesian optimization
- Multi-objective optimization (F1 score maximization)
- Early stopping and pruning mechanisms
- Cross-validation support
High Performance
- 99.28% accuracy on test set
- 97.35% F1 score
- 96.49% precision
- 98.21% recall
- Robust against overfitting
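As a quick consistency check, the F1 score follows directly from the reported precision and recall (harmonic mean); the small residual difference is rounding in the inputs:

```python
precision, recall = 0.9649, 0.9821   # reported test-set values

# F1 is the harmonic mean of precision and recall.
f1 = 2 * precision * recall / (precision + recall)
print(f"F1 = {f1:.4f}")  # -> F1 = 0.9734, matching the reported 97.35% up to rounding
```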
- Installation
- Project Structure
- Quick Start
- Complete Code Documentation
- Performance Metrics
- Technical Specifications
- Usage Examples
- Contributing
- License
- Acknowledgments
- Citations
- Future Roadmap
- Python: 3.7+
- CUDA: 11.0+ (optional, for GPU acceleration)
- RAM: 8GB+
- Storage: 2GB+
```bash
git clone https://github.com/your-username/sms-spam-detection.git
cd sms-spam-detection
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate
pip install -r requirements.txt
```

Or install the pinned versions directly:

```bash
pip install transformers==4.57.0 torch==2.8.0 pandas==2.3.2 numpy==2.3.3 scikit-learn==1.7.2 optuna==4.5.0 nltk==3.9.2 spacy==3.8.7 emoji==2.15.0 regex==2025.9.18 matplotlib==3.10.7 seaborn==0.13.2 tqdm==4.67.1
```

```
sms-spam-detection/
├── optimized.ipynb          # Main implementation notebook
├── README.md                # Documentation
├── requirements.txt         # Dependencies
├── LICENSE                  # MIT License
├── datasets/
│   └── spam.csv             # SMS spam dataset (5,572 messages)
├── models/
│   ├── best_model.pt        # Best performing model weights
│   └── model_config.json    # Model configuration
├── visualizations/
│   ├── confusion_matrix.png
│   ├── training_history.png
│   └── attention_maps/
├── docs/
│   ├── api_reference.md
│   ├── deployment_guide.md
│   └── troubleshooting.md
└── src/
    ├── preprocessor.py
    ├── model.py
    └── utils.py
```
- Run the notebook:

```bash
jupyter lab optimized.ipynb
```
- Quick prediction example:

```python
import torch
from transformers import DistilBertTokenizer
from src.model import SMSSpamClassifier  # project classifier defined in src/model.py

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = SMSSpamClassifier()
model.load_state_dict(torch.load('models/best_model.pt'))
model.eval()

def predict_spam(text):
    encoding = tokenizer(text, return_tensors='pt', max_length=128,
                         truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(encoding['input_ids'], encoding['attention_mask'])
    probabilities = torch.softmax(outputs, dim=1)
    confidence = probabilities.max().item()
    result = "SPAM" if probabilities.argmax() == 1 else "HAM"
    return result, confidence

examples = [
    "Hey, are we still meeting for lunch tomorrow?",
    "FREE! Win £1000 cash! Text WIN to 81010 now!",
    "Your Amazon order has been dispatched",
]

for msg in examples:
    result, confidence = predict_spam(msg)
    print(f"Message: '{msg[:40]}...' Prediction: {result} (Confidence: {confidence:.2%})")
```
See optimized.ipynb for a cell-by-cell walkthrough with code and explanation covering:
- Imports & Setup: All dependencies, reproducibility, and device selection.
- Data Loading: SMS dataset structure, distribution, and statistics.
- Preprocessing: SMS normalization, abbreviation expansion, emoji/money/percentage handling, tokenization, lemmatization.
- Model Architecture: DistilBERT classifier, selective layer freezing, attention weights.
- Training Setup: Stratified splits, custom dataset, weighted loss for imbalance, AdamW optimizer, scheduler, early stopping.
- Training & Validation: Training/evaluation functions, metrics tracking, early stopping.
- Hyperparameter Optimization: Optuna Bayesian search, pruning, cross-validation.
- Evaluation & Visualization: Metrics, confusion matrix, training history, attention heatmaps.
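The early-stopping logic used in the training loop is framework-independent; a minimal sketch with the documented patience of 3:

```python
class EarlyStopping:
    """Stop training when validation loss fails to improve for `patience` epochs."""
    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('inf')
        self.counter = 0
        self.should_stop = False

    def step(self, val_loss):
        if val_loss < self.best - self.min_delta:
            self.best = val_loss          # improvement: reset the patience counter
            self.counter = 0
        else:
            self.counter += 1
            self.should_stop = self.counter >= self.patience
        return self.should_stop

stopper = EarlyStopping(patience=3)
for epoch, val_loss in enumerate([0.41, 0.32, 0.35, 0.34, 0.36]):
    if stopper.step(val_loss):
        print(f"Stopping at epoch {epoch}")
        break
```

In the notebook the best model checkpoint is saved whenever `val_loss` improves, so stopping late costs nothing.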
| Metric | Train | Validation | Test |
|---|---|---|---|
| Accuracy | 99.18% | 99.28% | 99.28% |
| Precision | 98.82% | 96.49% | 96.49% |
| Recall | 99.21% | 98.21% | 98.21% |
| F1 Score | 99.01% | 97.35% | 97.35% |
| Loss | 0.0835 | 0.0456 | 0.0456 |
- Model: DistilBERT-base-uncased (6 layers, 768 hidden, 12 heads)
- Trainable Params: Last 2 layers + classifier (~8M params)
- Max Sequence Length: 128 tokens
- Data Split: 70:15:15 (train/val/test), stratified
- Weighted Loss: Spam class weighted by imbalance ratio (~6.46:1)
- Optimizer: AdamW, lr=2e-5, weight_decay=0.01
- Scheduler: Linear with warmup
- Early Stopping: Patience=3
- Frameworks: PyTorch, Hugging Face Transformers, Optuna
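The imbalance ratio behind the weighted loss follows directly from the label counts. The counts below (4,825 ham vs 747 spam, totaling the 5,572 messages in spam.csv) are an assumption to recompute from your copy of the dataset; the resulting vector is what gets passed as `weight=` to `torch.nn.CrossEntropyLoss`:

```python
# Assumed label counts for the bundled spam.csv; recompute from your data.
n_ham, n_spam = 4825, 747

imbalance_ratio = n_ham / n_spam           # ≈ 6.46
class_weights = [1.0, imbalance_ratio]     # index 0 = ham, 1 = spam
print(f"Spam class weighted {imbalance_ratio:.2f}:1")
```

Weighting the rare spam class up makes each missed spam message cost roughly as much as 6.5 misclassified ham messages, which is what drives the high recall in the table above.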
Single Message Classification

```python
def classify_message(text, return_confidence=False):
    preprocessor = SMSPreprocessor()
    processed_text = preprocessor.preprocess(text)
    encoding = tokenizer(processed_text, return_tensors='pt',
                         max_length=128, truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(encoding['input_ids'], encoding['attention_mask'])
    probabilities = torch.softmax(outputs, dim=1)
    prediction = outputs.argmax(dim=1).item()
    confidence = probabilities.max().item()
    result = "SPAM" if prediction == 1 else "HAM"
    if return_confidence:
        return result, confidence
    return result

messages = [
    "Hey, are we still meeting for lunch tomorrow at 12?",
    "URGENT! You've won £1000! Text WIN to 81010 to claim now!",
    "Your Amazon order #123-456 has been dispatched",
    "Free entry in 2 a wkly comp to win FA Cup final tkts",
]

for msg in messages:
    result, confidence = classify_message(msg, return_confidence=True)
    print(f"Message: '{msg[:40]}...' Classification: {result} (Confidence: {confidence:.1%})")
```

Batch Processing
```python
def classify_batch(messages, batch_size=32):
    preprocessor = SMSPreprocessor()
    processed_messages = [preprocessor.preprocess(msg) for msg in messages]
    results = []
    for i in range(0, len(processed_messages), batch_size):
        batch = processed_messages[i:i + batch_size]
        encodings = tokenizer(batch, return_tensors='pt',
                              max_length=128, truncation=True, padding=True)
        with torch.no_grad():
            outputs = model(encodings['input_ids'], encodings['attention_mask'])
        probabilities = torch.softmax(outputs, dim=1)
        predictions = outputs.argmax(dim=1).cpu().numpy()
        # max(dim=1) returns (values, indices); take .values before .numpy()
        confidences = probabilities.max(dim=1).values.cpu().numpy()
        for j, (pred, conf) in enumerate(zip(predictions, confidences)):
            results.append((messages[i + j], "SPAM" if pred == 1 else "HAM", conf))
    return results
```

We welcome contributions from the community!
- Bug Reports: Check existing issues first, use the bug-report template, and include the full traceback where applicable.
- Feature Requests: Describe your need, use case, and possible implementation.
- Pull Requests: Fork repo, create feature branch, write tests, format/lint code, update docs, and submit PR.
Code Standards:
- PEP 8 style
- Docstrings and type hints
- Comprehensive tests and documentation
This project is licensed under the MIT License - see the LICENSE file for details.
- SMS Spam Collection: Tiago A. Almeida & José María Gómez Hidalgo
- DistilBERT: Victor Sanh et al. (Hugging Face)
- Transformer: Vaswani et al.
- Libraries: Hugging Face, PyTorch, Optuna, Scikit-learn, NLTK, spaCy
```bibtex
@article{sms_spam_detection_2024,
  title={Advanced SMS Spam Detection with DistilBERT: A Comprehensive Implementation},
  author={Your Name},
  year={2024},
  journal={GitHub Repository},
  url={https://github.com/your-username/sms-spam-detection}
}

@article{almeida2011sms,
  title={SMS Spam Collection Data Set},
  author={Almeida, Tiago A and Hidalgo, José María Gómez},
  journal={UCI Machine Learning Repository},
  year={2011}
}

@article{sanh2019distilbert,
  title={DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter},
  author={Sanh, Victor and Debut, Lysandre and Chaumond, Julien and Wolf, Thomas},
  journal={arXiv preprint arXiv:1910.01108},
  year={2019}
}
```

- Multi-language support (Spanish, French, German, etc.)
- Real-time RESTful API
- Mobile deployment (TF Lite/ONNX)
- Active learning, federated learning, adversarial robustness
- Explainable AI, multimodal support, time-series pattern detection