Mark Schöne1,2 *, Babak Rahmani2 *, Heiner Kremer2, Fabian Falck2, Hitesh Ballani2, Jannes Gladrow2
1 TU Dresden, Germany, 2 Microsoft Research, Cambridge, UK
(*) Equal contribution.
🎉🚀 We are delighted to announce that our paper was spotlighted at ICML 2025! 🚀🎉
State-space models (SSMs) and transformers dominate the language modeling landscape. However, they are constrained to a lower computational complexity than classical recurrent neural networks (RNNs), limiting their expressivity. In contrast, RNNs lack parallelization during training, raising fundamental questions about the trade-off between parallelization and expressivity. We propose implicit SSMs, which iterate a transformation until convergence to a fixed point. Theoretically, we show that implicit SSMs implement the non-linear state transitions of RNNs. Empirically, we find that only approximate fixed-point convergence suffices, enabling the design of a scalable training curriculum that largely retains parallelization, with full convergence required only for a small subset of tokens. Our approach demonstrates superior state-tracking capabilities on regular languages, surpassing transformers and SSMs. We further scale implicit SSMs to natural language reasoning tasks and pretraining of large-scale language models up to 1.3B parameters on 207B tokens, representing, to our knowledge, the largest implicit model trained to date. Notably, our implicit models outperform their explicit counterparts on standard benchmarks.
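Conceptually, an implicit model iterates an update map until the hidden state (approximately) stops changing. The snippet below is only a generic fixed-point solve illustrating that idea, with a hypothetical update map `f`; it is not the implementation used in this repository.

```python
import torch

def fixed_point_solve(f, x, tol=1e-3, max_iters=64):
    """Iterate z <- f(z, x) until (approximate) convergence.

    `f` is a hypothetical update map; as noted in the paper, an approximate
    fixed point (loose tolerance, few iterations) is often sufficient.
    """
    z = torch.zeros_like(x)
    for _ in range(max_iters):
        z_next = f(z, x)
        # relative change criterion: stop once z is (approximately) a fixed point z* = f(z*, x)
        if torch.norm(z_next - z) < tol * torch.norm(z_next):
            return z_next
        z = z_next
    return z

# Toy usage with a contractive map (illustration only).
z_star = fixed_point_solve(lambda z, x: 0.5 * torch.tanh(z) + x, torch.randn(4, 8))
```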
Requirements:

- `mamba_ssm` and `causal_conv1d`

Install this package:

```bash
pip install .
```
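As an optional sanity check (not part of the official setup), you can verify that both dependencies import correctly; this assumes CUDA-enabled builds of the packages are installed.

```python
# Minimal check that the required kernels are importable.
import mamba_ssm
import causal_conv1d

print("mamba_ssm and causal_conv1d imported successfully")
```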
The code supports integration with the Hugging Face platform.
We provide local configuration files that can be loaded with `AutoConfig`:

```python
from transformers import AutoConfig, AutoModel

import implicit_llm  # registers the implicit architectures with the Hugging Face library

cfg = AutoConfig.from_pretrained('hf_models/llama3-1.3b-implicit')
model = AutoModel.from_config(cfg)
```

We provide a simple training script based on the Hugging Face `Trainer`. First, generate the dataset following the instructions. Then, train your models with:
```bash
python -m examples.state_tracking \
    --model_name hf_models/mamba2-state-tracking-implicit \
    --train_dataset /path/to/data/train_A5_L256_P090.bin \
    --eval_dataset /path/to/data/test_A5_L256_P050.bin \
    --test_dataset /path/to/data/test_A5_L256_P050.bin
```

The script works for arbitrary models from the Hugging Face hub. Feel free to train your favorite models!
To evaluate a trained model, use the `--eval` flag and point `--model_name` to the trained model checkpoint.
For example, run evaluation on the test set with 1024 tokens:
```bash
python -m examples.state_tracking \
    --model_name path/to/trained/model/checkpoint \
    --train_dataset /path/to/data/train_A5_L256_P090.bin \
    --eval_dataset /path/to/data/test_A5_L256_P050.bin \
    --test_dataset /path/to/data/test_A5_L256_P050.bin \
    --eval
```

By default, training always uses the simultaneous fixed-point iteration, while generation always uses the sequential fixed-point iteration.
We provide examples of evaluating a model in the sequential mode, e.g. to reproduce Figure 2C, in `tests/test_evaluation.py` and in `examples/state_tracking.py`.
The state-tracking example code uses the simultaneous mode for validation during training; a sequential pass over the test set is performed at the end of training.
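For intuition only, the sketch below contrasts the two modes on a self-contained toy update map `f` (a hypothetical stand-in, not the layer used in this repository): the simultaneous mode refines all token positions in parallel, while the sequential mode solves one position at a time given the already-converged prefix.

```python
import torch

def simultaneous_fixed_point(f, inputs, n_iters=16):
    """Training-style iteration: refine the hidden states of all tokens in parallel."""
    hidden = torch.zeros_like(inputs)
    for _ in range(n_iters):
        hidden = f(hidden, inputs)  # every call updates all positions at once
    return hidden

def sequential_fixed_point(f, inputs, n_iters=16):
    """Generation-style iteration: solve the fixed point token by token."""
    hidden = torch.zeros_like(inputs)
    for t in range(inputs.shape[1]):
        for _ in range(n_iters):
            # refine only position t, conditioned on the already-solved prefix
            hidden[:, t] = f(hidden[:, : t + 1], inputs[:, : t + 1])[:, -1]
    return hidden

# Toy demo with a contractive, position-wise update map (illustration only):
torch.manual_seed(0)
W = 0.1 * torch.randn(8, 8)                       # small weights -> contractive map
f = lambda h, x: torch.tanh(h @ W + x)
x = torch.randn(2, 5, 8)
# the two modes agree once both have converged to the fixed point
print(torch.allclose(simultaneous_fixed_point(f, x, 64),
                     sequential_fixed_point(f, x, 64), atol=1e-5))
```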
If loading a checkpoint fails with

```
ValueError: The checkpoint you are trying to load has model type `implicit_mamba2` but Transformers does not recognize this architecture.
```

just import `implicit_llm` to register the implicit models with the HF library, or call

```python
from implicit_llm import register_implicit_causal_lm

register_implicit_causal_lm()
```
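For illustration, loading a trained checkpoint after registration might look like the sketch below; the checkpoint path is a placeholder, and we assume the registered model follows the standard Transformers causal-LM interface (a config exposing `vocab_size`, outputs with `logits`).

```python
import torch
from transformers import AutoModelForCausalLM

from implicit_llm import register_implicit_causal_lm

# Register the implicit architectures (e.g. `implicit_mamba2`) before from_pretrained.
register_implicit_causal_lm()

# Placeholder path: point this at your own trained checkpoint directory.
model = AutoModelForCausalLM.from_pretrained('path/to/trained/model/checkpoint')
model.eval()

# Dummy forward pass, assuming the standard causal-LM interface.
input_ids = torch.randint(0, model.config.vocab_size, (1, 16))
with torch.no_grad():
    logits = model(input_ids).logits
print(logits.shape)
```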
This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.
This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact [email protected] with any additional questions or comments.
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft trademarks or logos is subject to and must follow Microsoft's Trademark & Brand Guidelines. Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. Any use of third-party trademarks or logos is subject to those third parties' policies.
@inproceedings{
schone2025implicit,
title={Implicit Language Models are {RNN}s: Balancing Parallelization and Expressivity},
author={Mark Sch{\"o}ne and Babak Rahmani and Heiner Kremer and Fabian Falck and Hitesh Ballani and Jannes Gladrow},
booktitle={Forty-second International Conference on Machine Learning},
year={2025},
url={https://openreview.net/forum?id=5EbiopWH6e}
}