A PyTorch implementation of ESRGAN for 4× image super-resolution, trained on DIV2K, Flickr2K, and ISR datasets.
This project implements ESRGAN (Enhanced Super-Resolution GAN) with the RRDBNet (Residual-in-Residual Dense Block Network) architecture. The model performs 4× upscaling of low-resolution images while preserving fine details and generating perceptually realistic textures.
- Two-phase training: PSNR-oriented pretraining followed by GAN-based perceptual refinement
- Relativistic Average GAN (RaGAN) loss for stable adversarial training
- Perceptual loss using VGG19 features
- Spectral normalization in discriminator for training stability
- Mixed precision training with automatic gradient scaling
- Progressive GAN weight warmup to prevent early instability
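The RaGAN formulation above scores each sample *relative to the average* of the opposite class rather than in isolation. A minimal sketch of the idea (the actual implementation lives in `src/losses.py` and may differ in details):

```python
import torch
import torch.nn.functional as F

def ragan_d_loss(real_logits, fake_logits):
    # Relativistic average: real samples should score higher than the
    # average fake, and fakes lower than the average real
    real_rel = real_logits - fake_logits.mean()
    fake_rel = fake_logits - real_logits.mean()
    loss_real = F.binary_cross_entropy_with_logits(real_rel, torch.ones_like(real_rel))
    loss_fake = F.binary_cross_entropy_with_logits(fake_rel, torch.zeros_like(fake_rel))
    return (loss_real + loss_fake) / 2

def ragan_g_loss(real_logits, fake_logits):
    # Generator objective: symmetric, with the labels flipped
    real_rel = real_logits - fake_logits.mean()
    fake_rel = fake_logits - real_logits.mean()
    loss_real = F.binary_cross_entropy_with_logits(real_rel, torch.zeros_like(real_rel))
    loss_fake = F.binary_cross_entropy_with_logits(fake_rel, torch.ones_like(fake_rel))
    return (loss_real + loss_fake) / 2
```

Because the generator's loss also flows gradients through the real-image logits, both halves of the minimax game stay informative, which is what makes RaGAN more stable than the standard GAN loss here.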
This project uses uv for dependency management.
```bash
# Clone the repository
git clone https://github.com/aidendorian/4x-Upscaler-ESRGAN
cd 4x-Upscaler-ESRGAN

# Install dependencies with uv
uv sync
```

Use the provided Jupyter notebook for easy inference:
- upscale.ipynb
Or run programmatically:
```python
import torch
import torchvision.transforms.functional as F
from PIL import Image

from src.esrgan import RRDBNet

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model
generator = RRDBNet().to(device)
generator.load_state_dict(torch.load("models/ESRGAN.pth", map_location=device))
generator.eval()

# Upscale image
lr_img = Image.open("input.jpg").convert("RGB")
lr_tensor = F.to_tensor(lr_img).unsqueeze(0).to(device)
with torch.no_grad():
    sr_tensor = generator(lr_tensor)
sr_img = F.to_pil_image(sr_tensor.squeeze(0).clamp(0, 1).cpu())  # clamp to valid [0, 1] range
sr_img.save("output.png")
```

**Important:** Always use `models/ESRGAN.pth` (the fully trained model), not `ESRGAN_PSNR.pth` (which is only the Phase 1 checkpoint).
- 23 RRDB blocks (Residual-in-Residual Dense Blocks)
- Each RRDB contains 3 Dense Residual Blocks (DRBs)
- Dense connections within blocks for better gradient flow
- Residual scaling factor of 0.2 for training stability
- Two 2× PixelShuffle upsampling stages for 4× total upscaling
- ~16.7M parameters
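The residual structure can be sketched as follows. This is a condensed illustration (fewer convolutions per dense block than the real network); the actual classes in `src/esrgan.py` may use different names and widths:

```python
import torch
import torch.nn as nn

class DenseBlockSketch(nn.Module):
    """Condensed dense residual block: each conv sees all earlier feature maps."""
    def __init__(self, nf=64, gc=32):
        super().__init__()
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1)
        self.conv3 = nn.Conv2d(nf + 2 * gc, nf, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        c1 = self.lrelu(self.conv1(x))
        c2 = self.lrelu(self.conv2(torch.cat([x, c1], dim=1)))  # dense connection
        c3 = self.conv3(torch.cat([x, c1, c2], dim=1))
        return x + 0.2 * c3  # residual scaling (0.2) for training stability

class RRDBSketch(nn.Module):
    """Residual-in-Residual: 3 dense blocks wrapped in an outer scaled residual."""
    def __init__(self, nf=64):
        super().__init__()
        self.blocks = nn.Sequential(*[DenseBlockSketch(nf) for _ in range(3)])

    def forward(self, x):
        return x + 0.2 * self.blocks(x)
```

Both residual paths are scaled by 0.2 before being added back, so each block perturbs the identity mapping only slightly, which keeps very deep stacks (23 RRDBs here) trainable.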
- VGG-style architecture with Spectral Normalization
- PatchGAN output (predicts real/fake for image patches)
- Progressive downsampling: 4 stride-2 convolutions
- No batch normalization (spectral norm provides regularization)
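A sketch of this discriminator layout, assuming hypothetical layer widths (the real architecture is in `src/esrgan.py`):

```python
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

def sn_conv(in_ch, out_ch, stride):
    # Spectral norm bounds the layer's Lipschitz constant,
    # standing in for batch norm as the regularizer
    return nn.Sequential(
        spectral_norm(nn.Conv2d(in_ch, out_ch, 3, stride, 1)),
        nn.LeakyReLU(0.2, inplace=True),
    )

# Four stride-2 convolutions halve the spatial size each time (x16 total)
discriminator_sketch = nn.Sequential(
    sn_conv(3, 64, 1),
    sn_conv(64, 64, 2),
    sn_conv(64, 128, 2),
    sn_conv(128, 256, 2),
    sn_conv(256, 512, 2),
    nn.Conv2d(512, 1, 3, 1, 1),  # PatchGAN head: one real/fake logit per patch
)
```

The PatchGAN head outputs a spatial grid of logits rather than a single score, so the discriminator judges local texture realism patch by patch.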
```
data/
├── DIV2K/
│   ├── hr_images/     # 800 high-res images (LR pairs via bicubic downsampling)
│   └── lr_images/     # Corresponding low-res images
├── Flickr2K/
│   ├── hr_images/     # 2,650 images (mixed degradation)
│   └── lr_images/
├── ISR/
│   ├── hr_images/     # 1,254 images (LR pairs via bicubic downsampling)
│   └── lr_images/
└── Validation/
    └── hr_images/     # Validation set
```
Total training images: 4,704
- Objective: Pixel-wise reconstruction (L1 loss only)
- Learning rate: 2e-4
- Target: ~26-27 dB validation PSNR
- Purpose: Establish strong baseline for content fidelity
- Losses: Pixel (1.0) + Perceptual (1.0) + GAN (0.005) + Color (0.05)
- GAN weight warmup: Linearly increases over first 20 epochs
- Learning rates: Generator (2e-4), Discriminator (2e-4)
- LR decay: 0.5× at 50%, 75%, 90% of GAN epochs
- Purpose: Add perceptual quality while preserving PSNR
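The warmup and decay schedule above can be sketched with the standard PyTorch scheduler API (constants taken from this README; exact code in `train.py` may differ):

```python
import torch

GAN_EPOCHS = 100
WARMUP_EPOCHS = 20
BASE_GAN_WEIGHT = 0.005

def gan_weight(epoch):
    # Linear warmup: 0 -> 0.005 over the first 20 GAN epochs,
    # then held constant
    return BASE_GAN_WEIGHT * min(1.0, epoch / WARMUP_EPOCHS)

# Stand-in parameter so the optimizer/scheduler can be constructed
g_opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=2e-4)

# 0.5x LR decay at 50%, 75%, and 90% of the GAN phase
milestones = [int(GAN_EPOCHS * f) for f in (0.5, 0.75, 0.9)]
scheduler = torch.optim.lr_scheduler.MultiStepLR(g_opt, milestones=milestones, gamma=0.5)
```

Ramping the adversarial term in gradually lets the generator keep its PSNR-oriented behavior early on, so the discriminator never faces (and punishes) wildly unrealistic outputs.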
```bash
cd src
uv run train.py
```

Key hyperparameters in `train.py`:
```python
BATCH_SIZE = 10
PSNR_EPOCHS = 100
GAN_EPOCHS = 100

LOSS_WEIGHTS = {
    'pixel': 1.0,
    'perceptual': 1.0,
    'gan': 0.005,   # keeps GAN subtle
    'color': 0.05,
}
```

- Validation PSNR: Should reach 26-27 dB by epoch 100
- Pixel loss: Steadily decreasing from ~0.08 → ~0.04
Monitor these metrics every 5 epochs:
| Metric | Healthy Range | Warning Signs |
|---|---|---|
| Validation PSNR | 25.5-27.0 dB | Dropping below 25.0 dB |
| D Loss | 0.55-0.75 | <0.4 (too weak) or >0.85 (too strong) |
| GAN Loss | 0.6-1.2 | <0.5 (G winning) or >1.5 (D winning) |
| Pixel Loss | 0.045-0.055 | Increasing over time |
Cause: GAN weight too high, perceptual losses overwhelming reconstruction
Solution:
```python
# In train.py, reduce GAN weight
LOSS_WEIGHTS = {
    'gan': 0.002,  # Reduced from 0.005
}
```

Symptoms: D loss dropping, GAN loss spiking (>1.5)
Cause: Discriminator too weak to provide useful gradients
Solution:
```python
# Boost discriminator learning rate temporarily
d_opt = Adam(discriminator.parameters(), lr=3e-4)  # Was 2e-4
g_opt = Adam(generator.parameters(), lr=1e-4)      # Lower generator LR
```

Or reload the discriminator from an earlier checkpoint:

```python
CHECKPOINT_PATH_D = 'checkpoints/discriminator_epoch_140.pth'  # Healthy epoch
```

Symptoms: GAN loss too low, discriminator winning easily
Solution:
```python
# Lower discriminator LR
d_opt = Adam(discriminator.parameters(), lr=5e-5)  # Was 2e-4
```

Symptoms: All losses flat for 20+ epochs
Solution:
- Reduce learning rates by 0.5×
- Increase GAN weight slightly (0.005 → 0.008)
- Check validation samples visually for quality
- Validation PSNR: 25.5-26.0 dB
- Perceptual quality: Sharp edges, realistic textures
- Color accuracy: Minimal color shift from original
```
ESRGAN/
├── src/
│   ├── __init__.py
│   ├── train.py          # Main training loop
│   ├── esrgan.py         # Generator & discriminator architectures
│   ├── losses.py         # Loss functions (pixel, perceptual, GAN)
│   ├── dataloader.py     # Dataset loading & augmentation
│   └── checkpoint.py     # Save/load training state
├── models/
│   ├── ESRGAN.pth        # Final trained model (USE THIS)
│   └── ESRGAN_PSNR.pth   # Phase 1 checkpoint (reference only)
├── checkpoints/          # Training checkpoints (auto-saved)
├── data/                 # Training datasets
├── images/               # Example images & outputs
├── val_output/           # Validation samples during training
├── demo_1.png            # Demo: Original vs Upscaled (zoomed)
├── demo_2.png            # Demo: Original vs Upscaled (zoomed)
├── demo_3.png            # Demo: Original vs Upscaled (zoomed)
├── upscale.ipynb         # Inference notebook
├── pyproject.toml        # uv dependencies
└── README.md
```
In dataloader.py, training images receive:
- JPEG compression (quality: 80-95, probability: 30-60%)
- Gaussian noise (σ: 0.002-0.007, probability: 30-50%)
- Geometric augmentations: Horizontal/vertical flips, 90° rotations
Validation images receive light degradation (deterministic per image) to match training distribution.
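The degradation pipeline can be approximated as below. This is a hedged sketch using the parameter ranges listed above; the probabilities chosen here are illustrative midpoints, and the real logic in `dataloader.py` may differ:

```python
import io
import random

import numpy as np
from PIL import Image

def degrade(img: Image.Image) -> Image.Image:
    # JPEG compression: quality 80-95, applied stochastically
    if random.random() < 0.5:
        buf = io.BytesIO()
        img.save(buf, format="JPEG", quality=random.randint(80, 95))
        buf.seek(0)
        img = Image.open(buf).convert("RGB")
    # Gaussian noise: sigma in [0.002, 0.007] on a [0, 1] intensity scale
    if random.random() < 0.4:
        arr = np.asarray(img, dtype=np.float32) / 255.0
        sigma = random.uniform(0.002, 0.007)
        arr = np.clip(arr + np.random.normal(0.0, sigma, arr.shape), 0.0, 1.0)
        img = Image.fromarray((arr * 255.0).round().astype(np.uint8))
    return img
```

Round-tripping through an in-memory JPEG buffer is a common way to inject realistic compression artifacts without touching the disk.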
- Generator: Gradient clipping at norm 10.0
- Discriminator: Gradient clipping at norm 1.0 (tighter to prevent collapse)
- Mixed precision: AMP with GradScaler for faster training
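A minimal AMP training step combining these pieces might look like the following. The model here is a stand-in; note that gradients must be unscaled *before* clipping, or the clip threshold is applied to scaled values:

```python
import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

model = nn.Linear(8, 8).to(device)  # stand-in for the generator
opt = torch.optim.Adam(model.parameters(), lr=2e-4)

def train_step(x, target):
    opt.zero_grad(set_to_none=True)
    with torch.autocast(device_type=device, enabled=use_amp):
        loss = nn.functional.l1_loss(model(x), target)
    scaler.scale(loss).backward()
    scaler.unscale_(opt)  # unscale before clipping so the norm is meaningful
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)  # generator: 10.0
    scaler.step(opt)
    scaler.update()
    return loss.item()
```

The discriminator step is analogous but clips at `max_norm=1.0`, matching the tighter bound listed above.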
```python
RESUME_FROM_CHECKPOINT = True
CHECKPOINT_PATH_G = 'checkpoints/generator_epoch_140.pth'
CHECKPOINT_PATH_D = 'checkpoints/discriminator_epoch_140.pth'
```

Checkpoints include:
- Model weights
- Optimizer state
- GradScaler state
- RNG state (for reproducibility)
- Iteration count
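The save/load logic for a checkpoint with that content can be sketched as a single bundled dict (function names here are hypothetical; see `src/checkpoint.py` for the real interface):

```python
import torch

def save_checkpoint(path, model, optimizer, scaler, epoch):
    # Bundle everything needed to resume training exactly where it stopped
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scaler": scaler.state_dict(),
        "rng_state": torch.get_rng_state(),  # for reproducible resumption
        "epoch": epoch,
    }, path)

def load_checkpoint(path, model, optimizer, scaler, device):
    # weights_only=False: the file is our own trusted checkpoint
    ckpt = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scaler.load_state_dict(ckpt["scaler"])
    torch.set_rng_state(ckpt["rng_state"])
    return ckpt["epoch"]
```

Restoring the optimizer and GradScaler state matters: resuming with fresh Adam moments or a reset loss scale can destabilize a GAN run for many epochs.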
```python
BATCH_SIZE = 6  # Reduce from 10
```

- Ensure `num_workers=4` in dataloaders
- Use `pin_memory=True` for GPU training
- Enable `persistent_workers=True`
- Check dataset HR/LR correspondence (should be 4× size difference)
- Verify images are properly preprocessed (RGB, correct sizes)
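A quick correspondence check could look like this. It assumes matching filenames under paired `hr_images/`/`lr_images/` directories (as in the dataset layout above):

```python
from pathlib import Path
from PIL import Image

def check_pairs(hr_dir, lr_dir, scale=4):
    # Returns a list of (filename, problem) tuples; empty means all pairs are valid
    mismatches = []
    for hr_path in sorted(Path(hr_dir).iterdir()):
        lr_path = Path(lr_dir) / hr_path.name
        if not lr_path.exists():
            mismatches.append((hr_path.name, "missing LR counterpart"))
            continue
        hr_w, hr_h = Image.open(hr_path).size
        lr_w, lr_h = Image.open(lr_path).size
        if (hr_w, hr_h) != (lr_w * scale, lr_h * scale):
            mismatches.append((hr_path.name, f"LR {(lr_w, lr_h)} vs HR {(hr_w, hr_h)}"))
    return mismatches
```

Run it once per dataset directory before training; a single mispaired image can quietly poison the pixel loss.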
- Increase PSNR epochs to 150 if needed
- Reduce GAN weight (0.005 → 0.003)
- Load earlier checkpoint with better PSNR
- Increase pixel loss weight (1.0 → 1.5)

