ARGformer is a transformer encoder based on ModernBERT for Ancestral Recombination Graph (ARG) data. It uses the FlexBERT architecture with YAML-based configuration.
The codebase builds on MosaicBERT and its unmerged fork with Flash Attention 2, released under the Apache 2.0 license.
For ModernBERT details, see the release blog post and arXiv preprint.
To set up the environment:

```bash
conda env create -f environment.yaml
conda activate bert24
pip install "flash_attn==2.6.3" --no-build-isolation
```

For H100 GPUs, optionally install Flash Attention 3:
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/hopper
python setup.py install
```

ARGformer supports:
- Pretraining: Masked language modeling on ARG sequences
- Contrastive Learning: Fine-tuning for retrieval and similarity tasks
- Embeddings: Extracting embeddings for downstream analysis
- Retrieval: Finding similar sequences in large corpora
ARG data is expected in the following directory structure:

```
/path/to/arg/data/
├── train/
│   ├── tokenized_train_sequences_and_vocab.pkl
│   └── labels.pkl  # Optional: for contrastive learning
└── val/
    ├── tokenized_val_sequences_and_vocab.pkl
    └── labels.pkl  # Optional: for contrastive learning
```
The ARGDataset class supports pretokenized sequences with vocabulary mappings for node IDs and special tokens ([PAD], [CLS], [SEP]).
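The exact pickle layout is defined by the data preparation script; as a rough illustration, a pretokenized file of this kind might bundle the integer-encoded sequences together with the vocabulary. The keys below are hypothetical placeholders, not the repository's actual schema:

```python
import pickle

# Hypothetical layout of tokenized_train_sequences_and_vocab.pkl;
# the real keys are defined in src/data/prepare_data_pretrain.py.
with open("train/tokenized_train_sequences_and_vocab.pkl", "rb") as f:
    data = pickle.load(f)

sequences = data["sequences"]  # e.g. list of token-ID lists, one per ARG sequence
vocab = data["vocab"]          # e.g. {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, ...} plus node-ID tokens

print(len(sequences), "sequences,", len(vocab), "vocabulary entries")
```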
Extract sequences from tree files using src/data/prepare_data_pretrain.py:
```bash
python src/data/prepare_data_pretrain.py
```

Edit the script to configure input paths, the output directory, and the train/val split.
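The script is configured by editing variables rather than passing command-line flags. The names below are illustrative placeholders only; the actual variables live in src/data/prepare_data_pretrain.py and may differ:

```python
# Illustrative placeholders -- not the script's actual variable names.
INPUT_TREE_DIR = "/path/to/tree/files"   # directory containing the input tree files
OUTPUT_DIR = "/path/to/arg/data"         # where the train/ and val/ pickles are written
TRAIN_FRACTION = 0.9                     # fraction of sequences used for training
```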
Configure yamls/mlm.yaml with dataset paths and model parameters, then run:
```bash
composer main.py yamls/mlm.yaml
```

Configure yamls/contrastive.yaml with the pretrained checkpoint path and run:
```bash
python sequence_contrastive.py yamls/contrastive.yaml
```

To extract embeddings for downstream analysis:

```bash
python embeddings.py [arguments]
```

See the script for usage examples.
To retrieve similar sequences from a corpus:

```bash
python retrieve.py [arguments]
```

See the script for usage examples.
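Conceptually, retrieval reduces to nearest-neighbor search over the extracted embeddings. The snippet below is a generic cosine-similarity sketch with NumPy, not the implementation in retrieve.py:

```python
import numpy as np

def top_k_similar(query_emb: np.ndarray, corpus_embs: np.ndarray, k: int = 5):
    """Return indices and scores of the k corpus embeddings most similar to the query."""
    query = query_emb / np.linalg.norm(query_emb)
    corpus = corpus_embs / np.linalg.norm(corpus_embs, axis=1, keepdims=True)
    scores = corpus @ query                       # cosine similarity against every corpus entry
    order = np.argsort(scores)[::-1][:k]          # indices of the top-k matches, best first
    return order, scores[order]

# Example with random vectors standing in for ARG embeddings:
rng = np.random.default_rng(0)
corpus = rng.normal(size=(1000, 768))
query = rng.normal(size=768)
indices, scores = top_k_similar(query, corpus, k=5)
```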
Training uses composer with YAML configuration files in yamls/:
- mlm.yaml: Pretraining configuration
- contrastive.yaml: Contrastive learning configuration
Key configuration sections:
- model: Model architecture and checkpoint paths
- train_loader / eval_loader: Dataset paths and data loading settings
- optimizer / scheduler: Training hyperparameters
- loggers: WandB logging configuration
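As a rough illustration, a config in yamls/ is organized along these lines; the field names and values below are illustrative, not copied from the repository's YAML files:

```yaml
# Illustrative sketch only -- see yamls/mlm.yaml for the actual fields.
model:
  name: flex_bert
  pretrained_checkpoint: null        # path to a checkpoint when fine-tuning

train_loader:
  data_dir: /path/to/arg/data/train
  batch_size: 64

eval_loader:
  data_dir: /path/to/arg/data/val
  batch_size: 64

optimizer:
  lr: 5.0e-4
  weight_decay: 1.0e-5

scheduler:
  name: linear_decay_with_warmup

loggers:
  wandb:
    project: argformer
```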