|
1 | | -# Spiking Decision Transformer |
| 1 | +<div align="center"> |
2 | 2 |
|
| 3 | +# Neuromorphic Decision Transformer (SNN-DT) |
| 4 | + |
| 5 | +[](https://opensource.org/licenses/MIT) |
| 6 | +[](https://www.python.org/downloads/) |
| 7 | +[](https://pytorch.org/) |
| 8 | +[](https://arxiv.org/abs/2508.21505) |
| 9 | + |
| 10 | +**Local Plasticity, Phase-Coding, and Dendritic Routing for Low-Power Sequence Control** |
| 11 | + |
| 12 | +</div> |
| 13 | + |
| 14 | +## Abstract |
| 15 | + |
| 16 | +This repository contains the official PyTorch implementation of the **Spiking Decision Transformer (SNN-DT)**, as presented in the paper *"Spiking Decision Transformers: Local Plasticity, Phase-Coding, and Dendritic Routing for Low-Power Sequence Control"* (Pandey & Biswas, 2025). |
| 17 | + |
| 18 | +The SNN-DT architecture bridges the gap between the sequential modeling capabilities of Transformers and the energy efficiency of Spiking Neural Networks (SNNs). By embedding **Leaky Integrate-and-Fire (LIF)** neurons within the self-attention mechanism and utilizing accurate **STDP-inspired local plasticity**, this model achieves state-of-the-art performance on continuous control tasks while reducing energy consumption by orders of magnitude compared to traditional ANN-based Decision Transformers. |
| 19 | + |
| 20 | +<div align="center"> |
| 21 | + <img src="visualizations/model_architecture.png" alt="SNN-DT Architecture" width="800"> |
| 22 | +</div> |
| 23 | +*(Note: Visualizations available in the `visualizations/` directory)* |
| 24 | + |
| 25 | +## Key Features |
| 26 | + |
| 27 | +- **Neuromorphic Efficiency**: Replaces standard activation functions with temporal spike-based logic, significantly reducing computational overhead suitable for edge deployment. |
| 28 | +- **Phase-Coded Positional Encoding**: A biologically plausible method for encoding sequence order using spike timing phases. |
| 29 | +- **Dendritic Routing**: Efficient information routing mechanism mimicking biological dendritic trees. |
| 30 | +- **Three-Factor Local Plasticity**: Implements STDP-like learning rules for robust weight updates without heavy backpropagation costs during inference-time adaptation. |
| 31 | +- **Standard Gym Benchmarks**: Evaluated on classic control tasks: `CartPole-v1`, `Pendulum-v1`, `MountainCar-v0`, and `Acrobot-v1`. |
| 32 | + |
| 33 | +## Installation |
| 34 | + |
| 35 | +System requirements: Linux/Windows, Python 3.8+, CUDA-enabled GPU (recommended). |
| 36 | + |
| 37 | +```bash |
| 38 | +# Clone the repository |
| 39 | +git clone https://github.com/Vishal-sys-code/neuromorphic_decision_transformer.git |
| 40 | +cd neuromorphic_decision_transformer |
| 41 | + |
| 42 | +# Create a virtual environment (optional but recommended) |
| 43 | +python -m venv venv |
| 44 | +source venv/bin/activate # On Windows: venv\Scripts\activate |
| 45 | + |
| 46 | +# Install dependencies |
| 47 | +pip install -r requirements.txt |
| 48 | +``` |
| 49 | + |
| 50 | +## Usage |
| 51 | + |
| 52 | +### Training |
| 53 | +To train the SNN-DT model on a specific environment (e.g., `Pendulum-v1`), use the provided training script. The training pipeline handles data generation, preprocessing, and model optimization. |
| 54 | + |
| 55 | +```bash |
| 56 | +# Run training for Pendulum-v1 |
| 57 | +python snn-dt/scripts/train.py --model snn_dt --env "Pendulum-v1" --save-dir "results/snn_dt_pendulum" |
| 58 | +``` |
| 59 | + |
| 60 | +To run the full suite of experiments across all environments: |
| 61 | +```bash |
| 62 | +./run_all_experiments.sh |
| 63 | +``` |
| 64 | + |
| 65 | +### Evaluation |
| 66 | +Evaluate a pre-trained checkpoint to measure return, spike counts, latency, and estimated energy consumption. |
| 67 | + |
| 68 | +```bash |
| 69 | +python eval_snn_dt.py \ |
| 70 | + --env "Pendulum-v1" \ |
| 71 | + --checkpoint_path "results/snn_dt_pendulum/best_model.pt" \ |
| 72 | + --target_return -200 \ |
| 73 | + --episodes 50 |
| 74 | +``` |
| 75 | + |
| 76 | +**Key Arguments:** |
| 77 | +- `--context_len`: Context length ($K$) for the transformer (default: 20). |
| 78 | +- `--per_spike_energy`: Estimated energy per spike in Joules (default: 4.6pJ for 45nm process). |
| 79 | + |
| 80 | +## Project Structure |
| 81 | + |
| 82 | +``` |
| 83 | +neuromorphic_decision_transformer/ |
| 84 | +├── configs/ # YAML configuration files for experiments |
| 85 | +├── snn-dt/ # Core SNN-DT source code |
| 86 | +│ ├── src/ # Model definitions and extensive util libraries |
| 87 | +│ └── scripts/ # Training and utility scripts |
| 88 | +├── demos/ # Demonstration notebooks and videos |
| 89 | +├── eval_snn_dt.py # Standalone evaluation script |
| 90 | +├── requirements.txt # Python dependencies |
| 91 | +└── run_all_experiments.sh # Batch experiment runner |
| 92 | +``` |
| 93 | + |
| 94 | +## Citation |
| 95 | + |
| 96 | +If you use this code or find our work helpful, please cite our paper: |
| 97 | + |
| 98 | +```bibtex |
| 99 | +@article{pandey2025spiking, |
| 100 | + title={Spiking Decision Transformers: Local Plasticity, Phase-Coding, and Dendritic Routing for Low-Power Sequence Control}, |
| 101 | + author={Pandey, Vishal and Biswas, Debasmita}, |
| 102 | + journal={arXiv preprint arXiv:2508.21505}, |
| 103 | + year={2025} |
| 104 | +} |
| 105 | +``` |
| 106 | + |
| 107 | +## License |
| 108 | + |
| 109 | +This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. |
0 commit comments