
Super Mario RL Agent 🎮

A deep reinforcement learning agent that learns to play Super Mario Bros. using Rainbow DQN, extended with a Spatial Transformer Network and a multi-branch architecture.


🌟 Features

This implementation includes state-of-the-art RL techniques that address common DQN problems:

Complete Rainbow DQN Implementation

All six Rainbow DQN components are implemented:

| Component | Purpose | Benefit |
|-----------|---------|---------|
| Double DQN | Separate action selection/evaluation | Reduces overestimation bias |
| Dueling Architecture | Separate value & advantage streams | Better state-value estimation |
| Distributional RL (C51) | Model the full value distribution | More stable learning |
| Noisy Networks | Learnable exploration noise | Better exploration than ε-greedy |
| Prioritized Replay | Sample important transitions | Improved sample efficiency |
| Multi-step Returns | n-step bootstrapping (n=3) | Faster credit assignment |
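
To make the multi-step row concrete, below is a minimal sketch of 3-step return accumulation, assuming (state, action, reward, next_state, done) tuples collected in a small deque. The helper name and tuple layout are illustrative, not the project's actual replay code.

from collections import deque

# Hypothetical helper: collapse the transitions in `window`
# (a deque of maxlen n) into a single n-step transition.
def make_nstep_transition(window: deque, gamma: float = 0.99):
    state, action = window[0][0], window[0][1]
    ret = 0.0
    for k, (_, _, reward, next_state, done) in enumerate(window):
        ret += (gamma ** k) * reward        # R = sum_k gamma^k * r_k
        if done:                            # episode ended inside the window
            return state, action, ret, next_state, True
    return state, action, ret, window[-1][3], False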

Advanced Architecture Features

  • 🎯 Spatial Transformer Network (STN): Learns to focus on relevant screen regions (enemies, gaps, power-ups)
  • 🔗 Multi-Branch Architecture: Combines visual features (CNN) with action history (MLP)
  • 🧠 Attention Mechanism: Adaptive spatial transformations for better feature extraction

๐Ÿ—๏ธ Architecture

Input: 4 stacked frames (84ร—84) + 8 action history
    โ†“
โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚ VISUAL BRANCH                           โ”‚
โ”‚  STN โ†’ CNN (Nature DQN) โ†’ Features     โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜
    โ†“
โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚ ACTION HISTORY BRANCH                   โ”‚
โ”‚  Embedding โ†’ MLP โ†’ Features            โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜
    โ†“
    Fusion (Concatenate)
    โ†“
โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”
โ”‚ DUELING HEADS (Noisy Layers)           โ”‚
โ”‚  Value Stream  โ†’ V(s)                  โ”‚
โ”‚  Advantage Stream โ†’ A(s,a)             โ”‚
โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜
    โ†“
Output: Q-value distribution (7 actions ร— 51 atoms)

Network Details

Visual Branch:

  • Spatial Transformer Network: 2D affine transformations
  • CNN Backbone: Nature DQN architecture
    • Conv2d(4→32, kernel=8, stride=4)
    • Conv2d(32→64, kernel=4, stride=2)
    • Conv2d(64→64, kernel=3, stride=1)
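
For illustration, here is a minimal sketch of how the STN front-end could look in PyTorch. The localization-network sizes are assumptions; only the affine_grid/grid_sample mechanism and the identity initialization follow the standard STN recipe.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SpatialTransformer(nn.Module):
    # A small localization net predicts a 2x3 affine matrix,
    # which is then applied to the stacked input frames.
    def __init__(self, in_channels=4):
        super().__init__()
        self.loc = nn.Sequential(
            nn.Conv2d(in_channels, 8, kernel_size=7, stride=2), nn.ReLU(),
            nn.Conv2d(8, 10, kernel_size=5, stride=2), nn.ReLU(),
            nn.AdaptiveAvgPool2d(4), nn.Flatten(),
            nn.Linear(10 * 4 * 4, 32), nn.ReLU(),
            nn.Linear(32, 6),               # 6 affine parameters
        )
        # Start at the identity transform so early training is unperturbed
        self.loc[-1].weight.data.zero_()
        self.loc[-1].bias.data.copy_(torch.tensor([1., 0., 0., 0., 1., 0.]))

    def forward(self, x):
        theta = self.loc(x).view(-1, 2, 3)
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)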

Action History Branch:

  • Embedding layer (7 actions → 32 dims)
  • MLP (256→128→128)

Dueling Heads with Noisy Layers:

  • Value Stream: NoisyLinear(512) → NoisyLinear(51)
  • Advantage Stream: NoisyLinear(512) → NoisyLinear(7×51)
  • Combine: Q(s,a) = V(s) + (A(s,a) - mean(A))
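
A minimal sketch of that per-atom combine, assuming the value head outputs shape (B, 1, 51) and the advantage head (B, 7, 51):

import torch.nn.functional as F

def dueling_c51_combine(value, advantage):
    # Q-logits per atom: V(s) + (A(s,a) - mean_a A(s,a))
    logits = value + advantage - advantage.mean(dim=1, keepdim=True)
    # Softmax over the 51 atoms: one probability distribution per action
    return F.softmax(logits, dim=2)

Expected Q-values then come from dotting each distribution with the atom support, e.g. q = (dist * support).sum(dim=2).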

📦 Installation

# Clone the repository
git clone https://github.com/pronzzz/super-mario-agent.git
cd super-mario-agent

# Create virtual environment (recommended)
python3 -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install dependencies
pip install -r requirements.txt

🚀 Usage

Training

python -m src.train

This will:

  • Train the agent on Super Mario Bros World 1-1
  • Save checkpoints every 500 episodes to ./mario_runs/<timestamp>/checkpoints/
  • Log metrics to TensorBoard

Monitor Training with TensorBoard

tensorboard --logdir=mario_runs

Open http://localhost:6006 to view:

  • Episode rewards (cumulative and moving average)
  • Loss curves
  • Q-value estimates
  • Episode lengths

Play with Trained Agent

python play.py --checkpoint mario_runs/<timestamp>/checkpoints/mario_net_final.pth

Watch the agent play in real time with live visualization!

Verify Installation

python verify.py

Runs unit tests for all components.

โš™๏ธ Configuration

Edit src/train.py to modify training parameters:

config = {
    'num_episodes': 10000,     # Total episodes
    'save_interval': 500,       # Checkpoint frequency
    'log_interval': 10,         # Console log frequency
    'device': 'cuda',           # 'cuda' or 'cpu'
    'save_dir': './mario_runs'  # Save directory
}

Key hyperparameters (in src/agent.py):

| Parameter | Value | Description |
|-----------|-------|-------------|
| Learning rate | 2.5e-4 | Adam optimizer learning rate |
| Discount factor (γ) | 0.99 | Future reward discount |
| Batch size | 32 | Samples per training step |
| Replay buffer | 100,000 | Maximum stored transitions |
| Target update | 10,000 steps | Sync frequency for target network |
| Learning starts | 50,000 steps | Initial exploration period |
| Multi-step (n) | 3 | n-step return horizon |
| C51 atoms | 51 | Distributional RL bins |
| Value range | [-10, 10] | Distribution support |
| PER α | 0.6 | Prioritization exponent |
| PER β | 0.4 → 1.0 | Importance-sampling annealing |

📊 Expected Results

Training progression on World 1-1:

| Episodes | Behavior | Avg reward |
|----------|----------|------------|
| 0-100 | Random exploration | ~100-200 |
| 100-500 | Learning to run right | ~300-600 |
| 500-1000 | Jumping over gaps | ~800-1500 |
| 1000-2000 | Consistent progress | ~1500-2500 |
| 2000-5000 | Level completion | ~2500-3000 |

Convergence: Reliable World 1-1 completion around 3000-5000 episodes (~12-20 hours on CPU, ~3-6 hours on GPU).

๐Ÿ—‚๏ธ Project Structure

super-mario-agent/
โ”œโ”€โ”€ src/
โ”‚   โ”œโ”€โ”€ __init__.py           # Package init
โ”‚   โ”œโ”€โ”€ wrappers.py           # Environment preprocessing
โ”‚   โ”œโ”€โ”€ model.py              # Rainbow DQN architecture
โ”‚   โ”œโ”€โ”€ replay.py             # Prioritized replay buffer
โ”‚   โ”œโ”€โ”€ agent.py              # Agent logic and learning
โ”‚   โ””โ”€โ”€ train.py              # Training loop
โ”œโ”€โ”€ requirements.txt          # Dependencies
โ”œโ”€โ”€ verify.py                 # Verification script
โ”œโ”€โ”€ play.py                   # Play with trained agent
โ”œโ”€โ”€ visualize.py              # Training visualization GUI
โ””โ”€โ”€ README.md                 # This file

๐Ÿ” Implementation Details

Environment Preprocessing

  1. SkipFrame: Repeat action for 4 frames (reduces computation by 4ร—)
  2. GrayScaleResize: RGB โ†’ Grayscale + resize to 84ร—84
  3. FrameStack: Stack last 4 frames (captures motion)
  4. ActionHistoryWrapper: Track last 8 actions (custom wrapper)

Action Space: 7 simple movements

  • NOOP, Right, Right+A, Right+B, Right+A+B, A, Left
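
These seven actions correspond to gym-super-mario-bros's SIMPLE_MOVEMENT set. A sketch of how the chain might be assembled, assuming the wrapper names in src/wrappers.py (the constructor keyword names here are guesses):

import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from nes_py.wrappers import JoypadSpace
from gym.wrappers import FrameStack

from src.wrappers import SkipFrame, GrayScaleResize, ActionHistoryWrapper

env = gym_super_mario_bros.make('SuperMarioBros-1-1-v0')
env = JoypadSpace(env, SIMPLE_MOVEMENT)         # the 7 movements listed above
env = SkipFrame(env, skip=4)                    # repeat each action for 4 frames
env = GrayScaleResize(env, size=84)             # RGB -> grayscale, 84x84
env = FrameStack(env, num_stack=4)              # stack the last 4 frames
env = ActionHistoryWrapper(env, history_len=8)  # track the last 8 actions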

Distributional RL (C51)

Uses categorical cross-entropy instead of MSE:

# Project target distribution onto 51-atom support
for each atom j:
    Tz = r + γⁿ * atom[j]
    Tz = clamp(Tz, v_min, v_max)

    # Linear interpolation to neighboring atoms
    b = (Tz - v_min) / Δz
    l, u = floor(b), ceil(b)

    # Distribute probability mass
    target_dist[l] += next_dist[j] * (u - b)
    target_dist[u] += next_dist[j] * (b - l)

loss = -Σ target_dist * log(current_dist)
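
For reference, a runnable PyTorch sketch of this projection (batched). Per Double DQN, next_dist would be the target network's distribution at the action chosen greedily by the online network. The l == u nudge keeps probability mass when an atom lands exactly on the support, a detail the pseudocode glosses over; shapes and argument names are illustrative.

import torch

def project_c51(next_dist, rewards, dones, gamma_n,
                v_min=-10.0, v_max=10.0, n_atoms=51):
    # next_dist: (B, n_atoms); rewards, dones: (B,); gamma_n = gamma ** n
    batch = next_dist.size(0)
    delta_z = (v_max - v_min) / (n_atoms - 1)
    support = torch.linspace(v_min, v_max, n_atoms, device=next_dist.device)

    # Bellman-update every atom, clamped to the support
    tz = rewards.unsqueeze(1) + (1 - dones.unsqueeze(1)) * gamma_n * support
    tz = tz.clamp(v_min, v_max)

    # Fractional index of each projected atom, and its neighbors
    b = (tz - v_min) / delta_z
    l, u = b.floor().long(), b.ceil().long()
    l[(u > 0) & (l == u)] -= 1            # avoid zero weights when b is integral
    u[(l < n_atoms - 1) & (l == u)] += 1

    # Distribute probability mass onto the neighboring atoms
    target = torch.zeros_like(next_dist)
    offset = (torch.arange(batch, device=next_dist.device) * n_atoms).unsqueeze(1)
    target.view(-1).index_add_(0, (l + offset).view(-1),
                               (next_dist * (u.float() - b)).view(-1))
    target.view(-1).index_add_(0, (u + offset).view(-1),
                               (next_dist * (b - l.float())).view(-1))
    return target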

Prioritized Experience Replay

  • SumTree data structure for O(log N) sampling
  • Priorities = |TD error| + ε
  • Importance sampling weights: w = (N * P(i))^(-β)
  • Beta annealing from 0.4 to 1.0 over 100k frames
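
A compact sketch of the sum-tree and the importance-sampling weight from the formula above; class and method names are illustrative, not necessarily those in src/replay.py.

import numpy as np

class SumTree:
    # Array-backed binary tree: leaves hold priorities, inner nodes partial sums.
    def __init__(self, capacity):
        self.capacity = capacity
        self.nodes = np.zeros(2 * capacity - 1)

    def total(self):
        return self.nodes[0]

    def update(self, leaf, priority):           # leaf index in [0, capacity)
        i = leaf + self.capacity - 1
        change = priority - self.nodes[i]
        while True:                             # propagate the delta to the root
            self.nodes[i] += change
            if i == 0:
                break
            i = (i - 1) // 2

    def sample(self, value):                    # value in [0, total()); O(log N)
        i = 0
        while i < self.capacity - 1:            # descend until reaching a leaf
            left = 2 * i + 1
            if value <= self.nodes[left]:
                i = left
            else:
                value -= self.nodes[left]
                i = left + 1
        return i - (self.capacity - 1), self.nodes[i]

def is_weight(prob, n_stored, beta):
    # w_i = (N * P(i))^(-beta), usually normalized by the batch maximum
    return (n_stored * prob) ** (-beta)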

Noisy Networks

Factorized Gaussian noise (Fortunato et al.):

weight = weight_μ + weight_σ ⊙ ε
ε = sign(x) * √|x|, where x ~ N(0,1)

  • Training: noisy weights (exploration)
  • Evaluation: mean weights (deterministic)
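
A sketch of a factorized-noise layer consistent with the description above; the sigma0 = 0.5 initialization follows Fortunato et al., the rest is a standard reference implementation rather than necessarily the project's.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class NoisyLinear(nn.Module):
    # weight = weight_mu + weight_sigma * (f(eps_out) outer f(eps_in))
    def __init__(self, in_features, out_features, sigma0=0.5):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features))
        self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features))
        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_sigma = nn.Parameter(torch.empty(out_features))
        self.register_buffer('eps_in', torch.zeros(in_features))
        self.register_buffer('eps_out', torch.zeros(out_features))
        bound = 1.0 / math.sqrt(in_features)
        nn.init.uniform_(self.weight_mu, -bound, bound)
        nn.init.uniform_(self.bias_mu, -bound, bound)
        nn.init.constant_(self.weight_sigma, sigma0 / math.sqrt(in_features))
        nn.init.constant_(self.bias_sigma, sigma0 / math.sqrt(in_features))
        self.reset_noise()

    @staticmethod
    def _f(x):
        return x.sign() * x.abs().sqrt()        # f(x) = sign(x) * sqrt(|x|)

    def reset_noise(self):
        self.eps_in.copy_(self._f(torch.randn(self.in_features)))
        self.eps_out.copy_(self._f(torch.randn(self.out_features)))

    def forward(self, x):
        if self.training:                       # noisy weights while training
            weight = self.weight_mu + self.weight_sigma * torch.outer(self.eps_out, self.eps_in)
            bias = self.bias_mu + self.bias_sigma * self.eps_out
        else:                                   # mean weights at evaluation
            weight, bias = self.weight_mu, self.bias_mu
        return F.linear(x, weight, bias)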

📈 Visualization

The project includes two visualization tools:

1. TensorBoard (Real-time Training Metrics)

tensorboard --logdir=mario_runs

2. GUI Visualizer (Agent Performance)

python visualize.py --checkpoint <path_to_checkpoint>

Features:

  • Live gameplay rendering
  • Real-time Q-value distribution visualization
  • Action history timeline
  • Performance metrics dashboard

🎯 What Makes This Special

Addresses Common DQN Problems

  1. Sample Inefficiency → Rainbow components + n-step + PER
  2. Overestimation Bias → Double DQN + distributional RL
  3. Poor Exploration → Noisy networks (no ε-greedy)
  4. Training Instability → Target network + gradient clipping + dueling
  5. Limited Attention → Spatial Transformer Network

Production-Ready Code

  • ✅ Comprehensive documentation
  • ✅ Type hints and docstrings
  • ✅ Modular design
  • ✅ TensorBoard integration
  • ✅ Checkpoint management
  • ✅ Verification tests

🔮 Future Enhancements

Potential extensions:

  1. Curiosity-Driven Exploration: Intrinsic rewards (RND, ICM)
  2. Data Augmentation: Random crop, color jitter (RAD)
  3. Distributed Training: IMPALA/Ape-X for parallel actors
  4. Generalization: Multi-level training and transfer learning
  5. Recurrent Networks: LSTM/GRU for partial observability
  6. Imitation Learning: Pretrain on human demonstrations

📚 References

This implementation is based on:

  1. Rainbow DQN: Hessel et al., 2018
  2. Spatial Transformer Networks: Jaderberg et al., 2015
  3. Human-level Control through Deep Reinforcement Learning: Mnih et al., 2015
  4. Prioritized Experience Replay: Schaul et al., 2016
  5. Noisy Networks: Fortunato et al., 2018

๐Ÿ“ License

MIT License - see LICENSE file for details

๐Ÿค Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

💡 Acknowledgments

  • OpenAI Gym and gym-super-mario-bros for the environment
  • PyTorch team for the deep learning framework
  • DeepMind for Rainbow DQN research

Built with ❤️ using PyTorch and reinforcement learning
