This project provides a robust, end-to-end pipeline for multi-label classification of chest X-ray images using deep learning and transfer learning. The model predicts the presence of multiple thoracic diseases in a single X-ray image, leveraging modern transformer-based architectures.
- Features
- Technologies Used
- Project Structure
- Dataset
- Installation
- Usage
- Model Training and Evaluation
- Testing the Model
- Results
- Future Improvements
- License
- Acknowledgments
- Handles multi-label classification for medical images (14 thoracic disease classes).
- Utilizes Swin Transformer (or configurable backbone) with transfer learning via timm.
- Implements advanced evaluation metrics: Precision, Recall, F1-score, and AUC (macro).
- Includes threshold optimization for best validation precision.
- Robust error handling and logging for data loading and preprocessing.
- Modular, extensible codebase for research and experimentation.
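
As a rough illustration of the macro AUC metric listed above, the sketch below computes a per-class AUC as the probability that a randomly chosen positive image scores higher than a randomly chosen negative one, then averages over classes. It is plain Python for readability and is not the torchmetrics implementation the project relies on; the function names `binary_auc` and `macro_auc` are hypothetical, and skipping classes that lack positives or negatives is an assumption about how undefined cases are handled.

```python
def binary_auc(scores, targets):
    """AUC for one class: chance that a random positive outranks a random
    negative, with ties counted as half a win."""
    pos = [s for s, t in zip(scores, targets) if t == 1]
    neg = [s for s, t in zip(scores, targets) if t == 0]
    if not pos or not neg:
        return None  # AUC is undefined unless both classes are present
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def macro_auc(score_rows, target_rows):
    """Unweighted mean of per-class AUCs, skipping undefined classes."""
    num_classes = len(target_rows[0])
    per_class = []
    for c in range(num_classes):
        auc = binary_auc([row[c] for row in score_rows],
                         [row[c] for row in target_rows])
        if auc is not None:
            per_class.append(auc)
    return sum(per_class) / len(per_class) if per_class else float("nan")


# Four images, two classes: predicted probabilities and ground-truth labels.
scores = [[0.9, 0.2], [0.8, 0.7], [0.3, 0.6], [0.1, 0.4]]
targets = [[1, 0], [1, 1], [0, 0], [0, 1]]
print(macro_auc(scores, targets))
```

Because AUC is computed from the ranking of scores, it does not depend on the classification threshold, unlike precision, recall, and F1.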
- Python 3.8+
- PyTorch for deep learning model implementation.
- timm for state-of-the-art vision backbones (Swin Transformer, etc.).
- torchmetrics for evaluation metrics.
- Pandas and NumPy for data manipulation.
- scikit-learn (optional, for further analysis).
- Pillow for image processing.
.
|-- script.py # Main implementation file (PyTorch pipeline)
|-- testscript.py # Script for testing the model
|-- data.csv # Input dataset (CSV format)
|-- images/ # Directory containing chest X-ray images
|-- output/ # Directory for model checkpoints and logs
|-- README.md # Documentation
|-- requirements.txt # Python dependencies
The project expects a dataset of chest X-ray images annotated with multiple disease labels. Each image is associated with one or more binary labels indicating the presence of specific conditions.
- The CSV file (`data.csv`) should contain:
  - A column `DicomPath_y` specifying the file path to each X-ray image.
  - Columns for each disease label (e.g., `Atelectasis`, `Cardiomegaly`), with values 0 (absent), 1 (present), or -1/NaN (treated as 0).
  - A `split` column indicating the data split (`train`, `validate`, or `test`).

Example:

DicomPath_y | Atelectasis | Cardiomegaly | Pneumonia | ... | split |
---|---|---|---|---|---|
images/img1.png | 1 | 0 | 0 | ... | train |
images/img2.png | 0 | 1 | 1 | ... | validate |
images/img3.png | 0 | 0 | 0 | ... | test |
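
The label convention described above (-1 or missing values treated as 0) can be sketched with the standard library alone. This is only an illustration: the helpers `clean_label` and `load_split` are hypothetical names, not functions from `script.py`, and the actual pipeline may well do the same thing with pandas.

```python
import csv
import io

# Columns that are not disease labels, per the CSV format described above.
NON_LABEL_COLUMNS = {"DicomPath_y", "split"}


def clean_label(raw):
    """Map a raw CSV cell to 0 or 1; -1, empty, and NaN are treated as 0."""
    text = (raw or "").strip().lower()
    if text in ("", "nan"):
        return 0
    return 1 if float(text) == 1.0 else 0


def load_split(csv_file, split):
    """Return (path, label_vector) pairs for one split, plus the label names."""
    reader = csv.DictReader(csv_file)
    label_names = [c for c in reader.fieldnames if c not in NON_LABEL_COLUMNS]
    rows = []
    for row in reader:
        if row["split"] != split:
            continue
        labels = [clean_label(row[name]) for name in label_names]
        rows.append((row["DicomPath_y"], labels))
    return rows, label_names


# Tiny in-memory stand-in for data.csv.
SAMPLE = io.StringIO(
    "DicomPath_y,Atelectasis,Cardiomegaly,Pneumonia,split\n"
    "images/img1.png,1,0,-1,train\n"
    "images/img2.png,0,1,1,validate\n"
    "images/img3.png,,0,nan,train\n"
)
train_rows, label_names = load_split(SAMPLE, "train")
print(label_names)
print(train_rows)
```

With a real file, the same helper could be called as `load_split(open("data.csv", newline=""), "train")`.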
- Clone the repository:

      git clone https://github.com/your-repo/multi-label-xray-classifier.git
      cd multi-label-xray-classifier

- Install required dependencies:

      pip install -r requirements.txt

- Prepare the dataset: place your `data.csv` and images in the appropriate directories as described above.
- Configure settings: edit the `Config` class in `script.py` to adjust the model backbone, batch size, learning rate, etc.
- Run the training and evaluation pipeline:

      python script.py

  The script will:
  - Load and preprocess the data.
  - Train the model with the specified configuration.
  - Evaluate on the validation and test sets.
  - Optimize the classification threshold for best precision.
  - Save the best model checkpoint and log results.
- Check the `output/` directory for saved models and logs.
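
For orientation, a configuration class of the kind edited in the "Configure settings" step might look like the sketch below. Every field name and default value here is an assumption made for illustration; the real `Config` class in `script.py` may be organized differently, so check the source before relying on any of these names.

```python
from dataclasses import dataclass, asdict


@dataclass
class Config:
    # All field names and defaults are illustrative assumptions,
    # not the actual contents of script.py.
    backbone: str = "swin_tiny_patch4_window7_224"  # a timm model name
    num_classes: int = 14       # thoracic disease classes
    batch_size: int = 32
    learning_rate: float = 1e-4
    threshold: float = 0.5      # starting classification threshold
    data_dir: str = "."
    csv_file: str = "data.csv"
    output_dir: str = "output"


# Override only the settings that differ from the defaults.
config = Config(batch_size=16, learning_rate=3e-5)
print(asdict(config))
```

Keeping the settings in one object makes it straightforward to store them alongside a checkpoint and restore them at test time.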
- Model: Swin Transformer (default) or a configurable backbone from `timm`.
- Loss: Binary Cross-Entropy with Logits (multi-label).
- Metrics: Precision, Recall, F1-score, AUC (macro, thresholded).
- Threshold Optimization: Automatically searches for the best threshold on the validation set.
- Logging: Detailed logs for each epoch, including metrics and checkpointing.
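
The threshold optimization step can be pictured as a simple grid search over candidate thresholds on the validation set. The sketch below is one plausible reading, not the project's actual code: it assumes a single global threshold, micro-averaged precision, and a fixed grid, and the names `micro_precision` and `best_threshold` are hypothetical.

```python
def micro_precision(prob_rows, target_rows, threshold):
    """Precision over all (image, class) pairs at one global threshold."""
    tp = fp = 0
    for probs, targets in zip(prob_rows, target_rows):
        for p, t in zip(probs, targets):
            if p >= threshold:
                if t == 1:
                    tp += 1
                else:
                    fp += 1
    # With no positive predictions precision is undefined; report 0.0 here.
    return tp / (tp + fp) if (tp + fp) else 0.0


def best_threshold(prob_rows, target_rows, candidates):
    """Return (threshold, precision) with the highest validation precision.
    Ties keep the earliest candidate, i.e. the lowest threshold on an
    ascending grid."""
    best = (None, -1.0)
    for thr in candidates:
        prec = micro_precision(prob_rows, target_rows, thr)
        if prec > best[1]:
            best = (thr, prec)
    return best


# Validation probabilities (after sigmoid) and labels: four images, two classes.
val_probs = [[0.9, 0.2], [0.6, 0.72], [0.55, 0.3], [0.1, 0.66]]
val_targets = [[1, 0], [1, 1], [0, 0], [0, 0]]
grid = [i / 20 for i in range(1, 20)]  # 0.05, 0.10, ..., 0.95
thr, prec = best_threshold(val_probs, val_targets, grid)
print(thr, prec)
```

Preferring the lowest threshold among ties is a design choice in this sketch: it keeps more positive predictions, and therefore more recall, at the same precision.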
To evaluate a trained model checkpoint on the test set, use `testscript.py`. The commands below assume you run them from the directory containing the script; otherwise, pass the full path to `testscript.py`.

Command (Bash/Linux/macOS):

    python testscript.py --checkpoint_path path/to/your/best_model.pth

Command (PowerShell/Windows):

    python .\testscript.py --checkpoint_path path\to\your\best_model.pth

Explanation:
- `--checkpoint_path`: (Required) The full path to the saved model checkpoint file (e.g., `output/best_model.pth` or `output\best_model.pth`, generated during training). Use the appropriate path separator for your OS (`/` for Linux/macOS, `\` for Windows).

Optional Arguments:
You can override settings stored in the checkpoint's configuration by providing additional arguments:
- `--data_dir`: Path to the directory containing the data and CSV file (e.g., `.`).
- `--csv_file`: Name of the CSV file (e.g., `data.csv`).
- `--batch_size`: Batch size for evaluation (larger sizes might speed up testing).
- `--device`: Device to run evaluation on (e.g., `cuda`, `cpu`).
- `--threshold`: Classification threshold to use (if you want to test a specific one instead of the one from the config).
- `--num_workers`: Number of data loading workers.
- `--output_dir`: Directory to save the test log file (defaults to `test_output`).
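
One way the override behaviour described above could be wired up with `argparse` is sketched here. The flag names and the `test_output` default come from the list above; everything else (the helper names `build_parser` and `apply_overrides`, the keys of the saved configuration, and the use of `None` as the "not provided" marker) is an assumption for illustration rather than the contents of `testscript.py`.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser(
        description="Evaluate a checkpoint on the test split.")
    parser.add_argument("--checkpoint_path", required=True)
    # Optional overrides default to None so that "not given" can be told
    # apart from an explicit value.
    parser.add_argument("--data_dir", default=None)
    parser.add_argument("--csv_file", default=None)
    parser.add_argument("--batch_size", type=int, default=None)
    parser.add_argument("--device", default=None)
    parser.add_argument("--threshold", type=float, default=None)
    parser.add_argument("--num_workers", type=int, default=None)
    parser.add_argument("--output_dir", default="test_output")
    return parser


def apply_overrides(saved_config, args):
    """Return a copy of the checkpoint config with CLI overrides applied."""
    merged = dict(saved_config)
    for key, value in vars(args).items():
        if key != "checkpoint_path" and value is not None:
            merged[key] = value
    return merged


# Stand-in for the configuration stored in a checkpoint (hypothetical keys).
saved_config = {"data_dir": "data", "csv_file": "data.csv", "batch_size": 16,
                "device": "cuda", "threshold": 0.45, "num_workers": 4}
args = build_parser().parse_args(
    ["--checkpoint_path", "output/best_model.pth",
     "--batch_size", "32", "--device", "cpu"])
merged = apply_overrides(saved_config, args)
print(merged)
```

Only the flags passed on the command line replace the stored values; in this example the threshold saved with the checkpoint (0.45) is kept because `--threshold` was not given.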
Example with overrides (Bash/Linux/macOS):

    python testscript.py \
      --checkpoint_path output/best_model.pth \
      --data_dir . \
      --csv_file data.csv \
      --batch_size 32 \
      --device cuda \
      --output_dir test_run_1

Example with overrides (PowerShell/Windows):

    # Single line:
    python .\testscript.py --checkpoint_path output\best_model.pth --data_dir . --csv_file data.csv --batch_size 32 --device cuda --output_dir test_run_1
    # Multi-line using backticks (`):
    python .\testscript.py `
      --checkpoint_path output\best_model.pth `
      --data_dir . `
      --csv_file data.csv `
      --batch_size 32 `
      --device cuda `
      --output_dir test_run_1

The script will load the specified checkpoint, prepare the test dataset based on the configuration (or overrides), run the evaluation, and print the performance metrics (Loss, Precision, Recall, F1-Score, AUC, Per-Class F1, Confusion Matrix) to the console and the log file within the specified output directory.
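
To make the per-class F1 and confusion matrix outputs concrete, the sketch below counts true/false positives and negatives separately for each class, which is the usual way to build confusion matrices for multi-label data, and derives F1 from those counts. It is a plain-Python illustration with hypothetical helper names (`per_class_confusion`, `f1_from_counts`), not the torchmetrics-based code the project uses.

```python
def per_class_confusion(pred_rows, target_rows):
    """Return one {tp, fp, fn, tn} dict per class for binary multi-label data."""
    num_classes = len(target_rows[0])
    counts = [{"tp": 0, "fp": 0, "fn": 0, "tn": 0} for _ in range(num_classes)]
    for preds, targets in zip(pred_rows, target_rows):
        for c, (p, t) in enumerate(zip(preds, targets)):
            if p == 1 and t == 1:
                counts[c]["tp"] += 1
            elif p == 1 and t == 0:
                counts[c]["fp"] += 1
            elif p == 0 and t == 1:
                counts[c]["fn"] += 1
            else:
                counts[c]["tn"] += 1
    return counts


def f1_from_counts(c):
    """F1 = 2*TP / (2*TP + FP + FN); defined as 0.0 when the denominator is 0."""
    denom = 2 * c["tp"] + c["fp"] + c["fn"]
    return 2 * c["tp"] / denom if denom else 0.0


# Thresholded predictions and labels: four images, two classes.
preds = [[1, 0], [1, 1], [0, 1], [0, 0]]
targets = [[1, 0], [0, 1], [0, 1], [1, 0]]
confusion = per_class_confusion(preds, targets)
for class_index, c in enumerate(confusion):
    print(class_index, c, f1_from_counts(c))
```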
- Best validation precision and corresponding threshold are reported.
- Final test set metrics (loss, precision, recall, F1, AUC) are logged.
- Example (replace with your actual results):

      Best validation precision: 0.82 at threshold 0.45
      Test set results:
        Loss: 0.23
        Precision: 0.81
        Recall: 0.78
        F1-score: 0.79
        AUC: 0.88

- Add support for additional or custom backbones (e.g., ConvNeXt, EfficientNet).
- Implement data augmentation and advanced regularization.
- Add per-class threshold optimization.
- Integrate explainability (Grad-CAM, saliency maps).
- Support for DICOM image loading and preprocessing.
- Distributed/multi-GPU training.
This project is licensed under the MIT License. See the LICENSE file for details.
- timm for backbone models.
- Dataset contributors for their invaluable work in medical imaging.
- PyTorch and torchmetrics developers for their open-source tools.