A novel architecture for medical text classification that uses a trainable sentence transformer with backpropagation to learn medically relevant text extraction through classification feedback. It is a work in progress.
This project implements a trainable embeddings model that learns to extract medically relevant sections from text through backpropagation from classification loss. The system uses:
- Trainable Sentence Transformer: Custom transformer that learns medical term extraction
- Attention Mechanism: Focuses on medically relevant parts of text
- Backpropagation: Classification loss improves the sentence transformer's extraction
- F1/RMSE Loss: An F1-based classification loss combined with an RMSE term that regularizes attention
- GCP Deployment: Containerized training on Google Cloud Platform (Vertex AI)
Pipeline overview:

```text
Medical Text Input
        ↓
Sentence Transformer (Trainable)
        ↓
Attention Mechanism (Medical Terms)
        ↓
Classification Head
        ↓
F1 Loss + Backpropagation
        ↓
Improved Medical Extraction
```
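The forward half of this pipeline can be sketched in NumPy as follows. This is a minimal sketch that assumes a single-layer model: token embeddings, multi-head attention pooling with one learned query per head, and a linear classification head. The parameter names, the query-based pooling and the random initialization are illustrative assumptions. This is not the code in `trainable_sentence_transformer.py`, and the "Medical Term Weights" component is not modeled. The sizes come from the model configuration in this README.

```python
import numpy as np

# Sizes taken from the model configuration in this README.
VOCAB_SIZE, EMBED_DIM, NUM_HEADS, NUM_CLASSES = 10_000, 768, 12, 5
HEAD_DIM = EMBED_DIM // NUM_HEADS  # 64 dimensions per head

rng = np.random.default_rng(0)
# Hypothetical parameter layout, randomly initialized for illustration.
params = {
    "embeddings": rng.normal(0, 0.02, (VOCAB_SIZE, EMBED_DIM)),
    "w_key": rng.normal(0, 0.02, (EMBED_DIM, EMBED_DIM)),
    "w_value": rng.normal(0, 0.02, (EMBED_DIM, EMBED_DIM)),
    "query": rng.normal(0, 0.02, (NUM_HEADS, HEAD_DIM)),  # one learned query per head
    "w_out": rng.normal(0, 0.02, (EMBED_DIM, NUM_CLASSES)),
    "b_out": np.zeros(NUM_CLASSES),
}

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def forward(token_ids, p=params):
    """Return (class probabilities [classes], attention weights [heads, seq_len])."""
    x = p["embeddings"][token_ids]                        # [seq, dim]
    seq_len = x.shape[0]
    keys = (x @ p["w_key"]).reshape(seq_len, NUM_HEADS, HEAD_DIM)
    values = (x @ p["w_value"]).reshape(seq_len, NUM_HEADS, HEAD_DIM)
    # Scaled dot-product score of every token against each head's query.
    scores = np.einsum("shd,hd->hs", keys, p["query"]) / np.sqrt(HEAD_DIM)
    attn = softmax(scores, axis=-1)                       # weights sum to 1 over tokens
    # Attention-weighted sum of values per head, then concatenate the heads.
    pooled = np.einsum("hs,shd->hd", attn, values).reshape(EMBED_DIM)
    logits = pooled @ p["w_out"] + p["b_out"]
    return softmax(logits), attn

probs, attn = forward(np.array([12, 845, 3071, 9]))  # arbitrary token ids
```

The attention weights are returned alongside the prediction so that the same array can serve both the attention regularizer and interpretability.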
The project uses the Medical Abstracts TC Corpus, which has 5 classes:
- Neoplasms (3,163 samples)
- Digestive system diseases (1,494 samples)
- Nervous system diseases (1,925 samples)
- Cardiovascular diseases (3,051 samples)
- General pathological conditions (4,805 samples)
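These counts total 14,438 samples and are imbalanced: the largest class has about 3.2 times as many samples as the smallest. This imbalance is the motivation for an F1-based loss. The sketch below computes class proportions and inverse-frequency weights from the listed counts. Inverse-frequency weighting is a common option shown only for illustration. It is an assumption, not necessarily something the training code does.

```python
# Sample counts per class, as listed above.
CLASS_COUNTS = {
    "Neoplasms": 3163,
    "Digestive system diseases": 1494,
    "Nervous system diseases": 1925,
    "Cardiovascular diseases": 3051,
    "General pathological conditions": 4805,
}

def inverse_frequency_weights(counts):
    """Weight each class by total / (num_classes * count).

    With these weights the average weight per *sample* is exactly 1,
    while rare classes get weights above 1 and frequent ones below 1.
    """
    total = sum(counts.values())
    k = len(counts)
    return {name: total / (k * n) for name, n in counts.items()}

total = sum(CLASS_COUNTS.values())
weights = inverse_frequency_weights(CLASS_COUNTS)
for name, n in CLASS_COUNTS.items():
    print(f"{name:32s} {n / total:6.1%}  weight={weights[name]:.2f}")
```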
Installation:

- Clone the repository:

  ```bash
  git clone <repository-url>
  cd medical-text-classification
  ```

- Set up the virtual environment:

  ```bash
  ./setup.sh
  source venv/bin/activate
  ```

- Install dependencies:

  ```bash
  pip install -r requirements.txt
  ```
Usage:

```bash
# Train the model
python train.py --experiment-name my-experiment

# Train with a custom configuration
python train.py --config config/custom_config.json

# Single-text prediction
python inference.py --model-path outputs/my-experiment_final_model.npz --text "Patient presents with chest pain..."

# Batch prediction from a file
python inference.py --model-path outputs/my-experiment_final_model.npz --input-file test_texts.txt

# Run the test suite
python test_training.py
```

Model configuration:

```json
{
  "vocab_size": 10000,
  "embedding_dim": 768,
  "max_seq_length": 512,
  "num_attention_heads": 12,
  "num_classes": 5,
  "learning_rate": 0.001
}
```

Training configuration:

```json
{
  "num_epochs": 20,
  "batch_size": 32,
  "learning_rate": 0.001,
  "lambda_f1": 1.0,
  "lambda_rmse": 0.1,
  "patience": 5
}
```

Prerequisites for GCP deployment:

- Google Cloud project with billing enabled
- Vertex AI API enabled
- Google Cloud SDK installed and configured
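Deployment steps:

- Build the Docker image:

  ```bash
  gcloud builds submit --tag gcr.io/YOUR_PROJECT/medical-text-classifier .
  ```

- Run the deployment script:

  ```bash
  chmod +x deploy_gcp.sh
  ./deploy_gcp.sh
  ```

- Monitor training:

  ```bash
  # Check training status
  gcloud ai custom-jobs list --region=us-central1

  # Inspect a job's details
  gcloud ai custom-jobs describe JOB_ID --region=us-central1

  # Stream a job's logs
  gcloud ai custom-jobs stream-logs JOB_ID --region=us-central1
  ```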
Approximate costs:

- Compute: ~$0.38/hour (n1-standard-8)
- GPU: ~$0.35/hour (NVIDIA T4)
- Storage: ~$0.026/GB/month
- Estimated total for 2-hour training: ~$1.50
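The ~$1.50 total follows from the hourly rates: (0.38 + 0.35) × 2 hours ≈ $1.46, plus storage. The small helper below redoes that arithmetic for other durations or rates. Its default rates are this README's approximations. Actual Vertex AI pricing depends on region and may differ from Compute Engine list prices.

```python
def estimate_training_cost(hours, compute_rate=0.38, gpu_rate=0.35,
                           storage_gb=0.0, storage_rate_month=0.026, months=1.0):
    """Rough cost in USD: (compute + GPU hourly rates) * hours + storage."""
    return (compute_rate + gpu_rate) * hours + storage_gb * storage_rate_month * months

print(f"${estimate_training_cost(2):.2f}")                  # prints $1.46 (compute + GPU only)
print(f"${estimate_training_cost(2, storage_gb=10):.2f}")   # adds 10 GB stored for one month
```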
Model components:

- Word Embeddings: Learned embeddings for medical vocabulary
- Attention Mechanism: Multi-head attention for medical term extraction
- Medical Term Weights: Learned importance weights for medical terms
- Backpropagation: Gradients flow from the classification loss back into the transformer

Loss functions:

- F1 Loss: Suited to imbalanced medical classification
- RMSE Loss: Regularizes the attention distribution
- Combined Loss: Weighted sum of the F1 and RMSE terms (`lambda_f1` and `lambda_rmse` in the training configuration)

Training step (per batch):

- Forward pass through the sentence transformer
- Extract medically relevant embeddings via attention
- Predict the class
- Compute the combined F1/RMSE loss
- Backpropagate through the transformer and classifier
- Update all parameters
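A hard F1 score is not differentiable, so the F1 term has to be computed from probabilities rather than hard counts. The sketch below covers the loss-computation step in NumPy. It uses a "soft" macro-F1, in which true positives, false positives and false negatives are summed from predicted probabilities. It also assumes, hypothetically, that the RMSE term measures the distance of the attention weights from a uniform distribution. The README does not specify the RMSE target. The sketch computes loss values only: the backpropagation and parameter update would come from an autodiff framework or hand-written gradients and are not shown.

```python
import numpy as np

def soft_macro_f1_loss(probs, labels, num_classes, eps=1e-8):
    """1 - mean per-class soft F1, using probabilities instead of hard counts."""
    y = np.eye(num_classes)[labels]            # one-hot targets, [batch, classes]
    tp = (probs * y).sum(axis=0)               # soft true positives per class
    fp = (probs * (1 - y)).sum(axis=0)         # soft false positives per class
    fn = ((1 - probs) * y).sum(axis=0)         # soft false negatives per class
    f1 = 2 * tp / (2 * tp + fp + fn + eps)
    return 1.0 - f1.mean()

def attention_rmse(attn):
    """RMSE between attention weights and a uniform distribution (hypothetical target)."""
    target = np.full_like(attn, 1.0 / attn.shape[-1])
    return np.sqrt(np.mean((attn - target) ** 2))

def combined_loss(probs, labels, attn, lambda_f1=1.0, lambda_rmse=0.1):
    # Default weights mirror lambda_f1 / lambda_rmse in the training configuration.
    num_classes = probs.shape[-1]
    return (lambda_f1 * soft_macro_f1_loss(probs, labels, num_classes)
            + lambda_rmse * attention_rmse(attn))
```

With perfect one-hot predictions and uniform attention, the combined loss is essentially 0. Uniform class probabilities over 3 classes give a soft-F1 loss of 2/3.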
The system is designed to demonstrate:
- Medical Term Extraction: Learns to focus on clinically relevant text
- Improved Classification: Classification feedback refines which parts of the text the transformer extracts
- Interpretability: Attention weights show which medical terms the model attends to
- Scalability: Intended for large-scale medical text processing
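Attention weights sum to 1 over the tokens of each head, so averaging them across heads gives a single importance score per token. The sketch below reads off the most-attended tokens. It assumes an attention array of shape `[heads, seq_len]`. The tokens and weights are made-up values for illustration, not model output, and `top_attended_tokens` is a hypothetical helper.

```python
import numpy as np

def top_attended_tokens(tokens, attn, k=3):
    """Return the k (token, weight) pairs with the highest head-averaged attention."""
    weights = np.asarray(attn).mean(axis=0)   # average over heads -> [seq_len]
    order = np.argsort(weights)[::-1][:k]     # token indices, highest weight first
    return [(tokens[i], float(weights[i])) for i in order]

tokens = ["patient", "presents", "with", "chest", "pain"]
attn = np.array([  # hypothetical weights: 2 heads over 5 tokens, each row sums to 1
    [0.05, 0.05, 0.02, 0.48, 0.40],
    [0.10, 0.04, 0.02, 0.34, 0.50],
])
print(top_attended_tokens(tokens, attn, k=2))  # "pain" first, then "chest"
```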
Project structure:

```text
medical-text-classification/
├── src/
│   ├── trainable_sentence_transformer.py  # Core model
│   ├── loss_functions.py                  # Custom losses
│   └── training.py                        # Training loop
├── train.py                               # Training script
├── inference.py                           # Inference script
├── test_training.py                       # Test suite
├── gcp_config.py                          # GCP deployment
├── Dockerfile                             # Container config
├── requirements.txt                       # Dependencies
├── setup.sh                               # Environment setup
└── data/                                  # Dataset
    ├── medical_tc_train.csv
    ├── medical_tc_test.csv
    └── medical_tc_labels.csv
```
This project is licensed under the MIT License - see the LICENSE file for details.
To contribute:

- Fork the repository
- Create a feature branch
- Make your changes
- Add tests
- Submit a pull request
For questions or issues, please open a GitHub issue or contact the maintainers.