🚀 TransUNet - Unofficial Implementation of Transformer-Based U-Net for Medical Image Segmentation

This is an unofficial implementation of TransUNet:

"TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation"
Jieneng Chen, Yongyi Lu, Qihang Yu, et al.

TransUNet is a hybrid deep learning model that combines CNNs (Convolutional Neural Networks) and Transformers to improve segmentation accuracy in medical imaging. This repository provides a fully configurable and easy-to-use implementation of TransUNet.


📌 Key Features

✅ Unofficial Implementation based on the original paper
✅ Hybrid CNN + Transformer architecture for enhanced segmentation
✅ Supports multiple optimizers: Adam, AdamW, SGD
✅ Configurable loss functions: BCE, Focal, Tversky
✅ Multi-device support: CPU, GPU, Apple M1/M2 (mps)
✅ Automatic dataset handling, model training, and evaluation
✅ Saves segmentation masks and best-performing model checkpoints


📌 Model Architecture

The TransUNet model follows a three-stage approach:
1️⃣ CNN Encoder - Extracts local spatial features
2️⃣ Transformer Block - Captures global dependencies
3️⃣ Decoder - Combines CNN and Transformer outputs for segmentation

Input Image (H x W x C)
       │
       ▼
█████ CNN Encoder █████   (Extracts feature maps)
       │
       ▼
-----> Transformer Block ----->   (Captures long-range dependencies)
       │
       ▼
█████ Decoder (UpSampling + Skip Connections) █████
       │
       ▼
█████ Segmentation Mask (Output) █████
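As a quick shape walk-through of this pipeline (pure Python; the 16x cumulative CNN downsampling factor is an illustrative assumption, and the 128-pixel input follows the default image_size in config.yaml):

```python
# Token-count arithmetic for the hybrid encoder.
# Assumption: the CNN backbone reduces spatial resolution by 16x in total;
# the real factor depends on the chosen backbone.
image_size = 128                          # dataloader.image_size in config.yaml
downsample = 16                           # assumed cumulative CNN stride

feature_side = image_size // downsample   # side length of the CNN feature map
num_tokens = feature_side ** 2            # sequence length seen by the Transformer

print(feature_side, num_tokens)           # -> 8 64
```

So for a 128x128 input the Transformer block attends over 64 tokens, each corresponding to one spatial location of the CNN feature map.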

📌 Installation

1️⃣ Clone the Repository

git clone https://github.com/atikul-islam-sajib/TransUNet.git
cd TransUNet

2️⃣ Install Dependencies

pip install -r requirements.txt

3️⃣ (Optional) Create a Virtual Environment

python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

📌 Configuration - config.yaml

Before running training or testing, modify config.yaml to set paths, model parameters, and training options.

# 📌 Configuration File for TransUNet
# This file defines paths, data settings, model architecture, training parameters, and inference settings.

# 🔹 Paths for storing raw data, processed data, model checkpoints, and outputs
artifacts:
  raw_data_path: "./data/raw/"                         # Directory for raw dataset files
  processed_data_path: "./data/processed/"             # Directory for preprocessed dataset
  files_path: "./artifacts/files/"                     # General storage for generated files
  train_models: "./artifacts/checkpoints/train_models/"  # Directory to store trained models
  best_model: "./artifacts/checkpoints/best_model/"      # Directory for best model checkpoints
  metrics_path: "./artifacts/metrics/"                 # Path to store training/testing metrics
  train_images: "./artifacts/outputs/train_images/"    # Folder to store images generated during training
  test_image: "./artifacts/outputs/test_image/"        # Folder to store predicted test images

# 🔹 Dataset and dataloader settings
dataloader:
  image_path: "./data/raw/dataset.zip"  # Path to the dataset (ZIP format or unzipped folder)
  image_channels: 3                     # Number of image channels (3 for RGB, 1 for grayscale)
  image_size: 128                        # Image resolution (e.g., 128x128)
  batch_size: 8                          # Number of images per batch
  split_size: 0.30                       # Percentage of data used for validation (e.g., 30% validation)

# 🔹 TransUNet Model Configuration
TransUNet:
  nheads: 4              # Number of attention heads in the transformer encoder
  num_layers: 4          # Number of transformer encoder layers
  dim_feedforward: 512   # Hidden layer size in the feedforward network
  dropout: 0.3           # Dropout rate for regularization (higher value prevents overfitting)
  activation: "gelu"     # Activation function ("gelu" or "relu")
  layer_norm_eps: 1e-05  # Epsilon value for layer normalization (stabilizes training)
  bias: False            # Whether to use bias in transformer layers (True/False)

# 🔹 Training Configuration
trainer:
  epochs: 100            # Number of epochs for training
  lr: 0.0001             # Learning rate for optimization
  optimizer: "AdamW"     # Selected optimizer: "Adam", "AdamW", or "SGD"

  # Optimizer configurations (fine-tuning parameters)
  optimizers:
    Adam: 
      beta1: 0.9
      beta2: 0.999
      weight_decay: 0.0001
    SGD: 
      momentum: 0.95
      weight_decay: 0.0
    AdamW:
      beta1: 0.9
      beta2: 0.999
      weight_decay: 0.0001

  # Loss function settings
  loss: 
    type: "bce"           # Type of loss function: "bce", "focal", or "tversky"
    loss_smooth: 1e-06    # Smoothing factor for loss computation (prevents overconfidence)
    alpha_focal: 0.75     # Alpha value for focal loss (balances class distribution)
    gamma_focal: 2        # Gamma value for focal loss (higher values focus on hard examples)
    alpha_tversky: 0.75   # Alpha parameter for Tversky loss (controls false positives)
    beta_tversky: 0.5     # Beta parameter for Tversky loss (controls false negatives)

  # Regularization settings (helps prevent overfitting)
  l1_regularization: False       # Enable L1 regularization (True/False)
  elastic_net_regularization: False  # Enable elastic net regularization (True/False)

  verbose: True       # Display progress and save images during training (True/False)
  device: "cuda"      # Device for training: "cuda" (GPU), "mps" (Mac M1/M2), or "cpu"

# 🔹 Testing Configuration
tester:
  dataset: "test"  # Dataset used for testing
  device: "cuda"   # Device to use for testing: "cuda", "mps", or "cpu"

# 🔹 Inference Configuration
inference:
  image: "./artifacts/data/processed/sample.jpg"  # Path to the image used for inference
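Once config.yaml is in place, the training code can read it with PyYAML (a sketch, assuming PyYAML is among the dependencies in requirements.txt; a trimmed inline copy of the trainer section stands in for the real file here):

```python
import yaml  # PyYAML; assumed to be installed via requirements.txt

# Trimmed inline copy of the trainer section, standing in for config.yaml
raw_config = """
trainer:
  epochs: 100
  lr: 0.0001
  optimizer: "AdamW"
  optimizers:
    AdamW:
      beta1: 0.9
      beta2: 0.999
      weight_decay: 0.0001
"""

config = yaml.safe_load(raw_config)
trainer_cfg = config["trainer"]

# Look up the hyperparameters of whichever optimizer is selected
opt_name = trainer_cfg["optimizer"]               # "AdamW"
opt_params = trainer_cfg["optimizers"][opt_name]  # its sub-dict of settings
```

With the real file you would replace `raw_config` by `open("config.yaml")` and pass the handle to `yaml.safe_load`; the nested-key lookups stay the same.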

📌 Explanation of Key Sections

| Section | Description |
| --- | --- |
| artifacts | Defines storage paths for datasets, model checkpoints, and outputs. |
| dataloader | Specifies dataset path, image properties, batch size, and validation split. |
| TransUNet | Defines model architecture, including Transformer layers and activation functions. |
| trainer | Configures training parameters like epochs, optimizer, and loss functions. |
| tester | Specifies dataset and device settings for evaluation. |
| inference | Defines the path to an image for making predictions. |
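Of the configurable loss functions in the trainer section, the Tversky loss is the least standard. A minimal plain-Python sketch (illustrative only, not the repository's actual implementation; alpha, beta, and smooth mirror alpha_tversky, beta_tversky, and loss_smooth in config.yaml):

```python
def tversky_loss(pred, target, alpha=0.75, beta=0.5, smooth=1e-6):
    """Illustrative Tversky loss over flat lists of probabilities/labels.

    alpha weights false positives, beta weights false negatives;
    alpha == beta == 0.5 reduces to the Dice loss.
    """
    tp = sum(p * t for p, t in zip(pred, target))        # true positives
    fp = sum(p * (1 - t) for p, t in zip(pred, target))  # false positives
    fn = sum((1 - p) * t for p, t in zip(pred, target))  # false negatives
    tversky_index = (tp + smooth) / (tp + alpha * fp + beta * fn + smooth)
    return 1.0 - tversky_index

# A perfect prediction drives the loss to zero
print(tversky_loss([1.0, 0.0, 1.0], [1, 0, 1]))  # -> 0.0
```

Raising alpha above beta (as in the defaults above) penalizes false positives more heavily than false negatives, which is a common choice when over-segmentation is the bigger risk.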

📌 Running Training & Testing

| Process | Command | Description |
| --- | --- | --- |
| Train Model | python src/cli.py --train | Starts model training using config.yaml |
| Test Model | python src/cli.py --test | Runs evaluation on test data |
| Change Optimizer | Edit config.yaml | Supports "SGD", "Adam", "AdamW" |

📌 Viewing Results

| Output | Saved Location | Description |
| --- | --- | --- |
| Model Checkpoints | ./artifacts/checkpoints/train_models/ | Stores model checkpoints during training. |
| Best Model | ./artifacts/checkpoints/best_model/ | Saves the best-performing model based on validation metrics. |
| Test Predictions | ./artifacts/outputs/test_image/ | Stores predicted segmentation masks from test data. |
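To locate the most recent checkpoint in these directories programmatically, a small helper like the following works (the .pth extension is an assumption about how this repository names its checkpoint files):

```python
import glob
import os

def latest_checkpoint(checkpoint_dir):
    """Return the most recently modified .pth file in checkpoint_dir, or None."""
    candidates = glob.glob(os.path.join(checkpoint_dir, "*.pth"))
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)

# Usage with the paths from config.yaml's artifacts section:
# ckpt = latest_checkpoint("./artifacts/checkpoints/train_models/")
```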

📌 TransUNet Workflow

| Step | Process | Description |
| --- | --- | --- |
| 1️⃣ | Load Dataset | Load and preprocess the dataset for training. |
| 2️⃣ | Train Model | Train TransUNet using the dataset and save checkpoints. |
| 3️⃣ | Evaluate Model | Validate model performance on test data. |
| 4️⃣ | Generate Predictions | Apply the trained model to test images and generate segmentation masks. |

📌 Citation

If you use this repository, please consider citing the original TransUNet paper:

@article{chen2021transunet,
  title={TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation},
  author={Chen, Jieneng and Lu, Yongyi and Yu, Qihang and others},
  journal={arXiv preprint arXiv:2102.04306},
  year={2021}
}

📌 License

This project is open-source and available under the MIT License.
🚨 Note: This is an unofficial implementation and is not affiliated with the original authors.
