# Synthetic Continued Pretraining

This repository contains the code for the paper [Synthetic Continued Pretraining](https://arxiv.org/abs/2409.07431).
This codebase implements the entire pipeline for synthetic continued pretraining using the EntiGraph synthetic data generator. It includes:
- Code for generating synthetic data with EntiGraph
- Scripts for continued pretraining with Llama 3 8B
- Evaluation tools for the continually pretrained model
- Scripts for instruction tuning of the continually pretrained model
- Interactive chatbot based on the instruction-tuned model
Table of contents:
- [Installation](#installation)
- [EntiGraph Synthetic Continued Pretraining](#entigraph-synthetic-continued-pretraining)
- [Instruction Tuning on Continued Pretrained Model](#instruction-tuning-on-continued-pretrained-model)
- [Citation](#citation)
## Installation

```bash
git clone https://github.com/ZitongYang/Synthetic_Continued_Pretraining.git
cd Synthetic_Continued_Pretraining
pip install -r requirements.txt
huggingface-cli login --token <huggingface token>  # required: grants access to the Llama 3 pretrained weights
wandb login <weights and bias token>                # optional: skip if you don't want to log your training runs
```

## EntiGraph Synthetic Continued Pretraining

Our experiments use the QuALITY dataset as the source documents.
- Set your OpenAI API key in `data/dataset/openai.key`.
- To run the EntiGraph procedure for the i-th article using `gpt-4-turbo`:

```bash
python data/entigraph.py i
```

The resulting synthetic data will be saved in `data/dataset/raw/quality_entigraph_gpt-4-turbo/`.
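The script above processes one article per invocation. If you want to generate synthetic data for the whole corpus, a minimal driver loop might look like the sketch below; the article count is a placeholder, so check the QuALITY split used by `data/entigraph.py` for the actual number:

```python
import subprocess

# Placeholder: set this to the number of articles in your QuALITY split.
NUM_ARTICLES = 265

for i in range(NUM_ARTICLES):
    # Each call runs the EntiGraph procedure for article i with gpt-4-turbo.
    subprocess.run(["python", "data/entigraph.py", str(i)], check=True)
```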
Tokenize the EntiGraph synthetic data:

```bash
mkdir -p data/dataset/bins/
python data/tokenize_entigraph.py
```

This will save the resulting binary file in `data/dataset/bins/quality_all-graphgpt-4-turbo.bin`.
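As a quick sanity check, you can count the tokens in the resulting binary file. The sketch below assumes the file is a flat array of 32-bit token IDs; check `data/tokenize_entigraph.py` for the actual dtype before relying on the count:

```python
import numpy as np

BIN_PATH = "data/dataset/bins/quality_all-graphgpt-4-turbo.bin"

# Assumption: tokens are stored as a flat binary array of 32-bit integers
# (Llama 3's vocabulary exceeds 65,535, so 16-bit IDs would not fit).
tokens = np.memmap(BIN_PATH, dtype=np.int32, mode="r")
print(f"{len(tokens):,} tokens in {BIN_PATH}")
```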
Download and tokenize 1B tokens of the RedPajama dataset as replay data:

```bash
python data/tokenize_redpj.py
```

This will save two binary files:
- `data/dataset/bins/togethercomputer_RedPajama_Data_1T_Sample_None_train.bin`
- `data/dataset/bins/togethercomputer_RedPajama_Data_1T_Sample_None_test.bin`
To inspect the generated synthetic data:

```bash
python data/cptdata.py
```

To perform continued pretraining on Llama 3 8B using the EntiGraph synthetic data:
```bash
chmod 777 scripts/train.sh
./scripts/train.sh \
    --lr 5e-06 \
    --rr 0.1 \
    --epochs 2 \
    --bs 16 \
    --wd 0.01 \
    --warmup 0.05 \
    --task_name quality
```

Arguments:
- `--lr`: Peak learning rate
- `--rr`: RedPajama replay rate
- `--epochs`: Total epochs to run
- `--bs`: Batch size
- `--wd`: Weight decay factor
- `--task_name`: Dataset choice (`quality` for EntiGraph synthetic data, `instruct` for UltraChat instruction tuning data)
The resulting checkpoint will be saved under `ckpts/quality-lr5e-06-rr0.1-epochs2-bs16-wd0.01-warmup0.05-MetaLlama38B`.
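If the checkpoint is written in the standard Hugging Face format (an assumption; inspect the directory to confirm), it can be loaded like any other `transformers` causal LM:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Path follows the naming scheme produced by scripts/train.sh above.
CKPT = "ckpts/quality-lr5e-06-rr0.1-epochs2-bs16-wd0.01-warmup0.05-MetaLlama38B"

tokenizer = AutoTokenizer.from_pretrained(CKPT)
model = AutoModelForCausalLM.from_pretrained(
    CKPT, torch_dtype=torch.bfloat16, device_map="auto"
)
```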
To evaluate on the QuALITY QA set:

```bash
python evaluation.py --model_path=ckpts/quality-lr5e-06-rr0.1-epochs2-bs16-wd0.01-warmup0.05-MetaLlama38B
```

The output will be stored in `out/qualityqa-quality-lr5e-06-rr0.1-epochs2-bs16-wd0.01-warmup0.05-MetaLlama38B.json`.

To parse the output into accuracy metrics, refer to `notebooks/nb_qa_eval.ipynb`.
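The notebook is the reference implementation for scoring. If you only need a rough accuracy number, a sketch like the one below may suffice; it assumes the output JSON is a list of records with hypothetical `predicted` and `answer` fields, so check the notebook for the actual schema:

```python
import json

OUT_PATH = "out/qualityqa-quality-lr5e-06-rr0.1-epochs2-bs16-wd0.01-warmup0.05-MetaLlama38B.json"

with open(OUT_PATH) as f:
    records = json.load(f)

# Hypothetical field names: adapt to the actual schema used by evaluation.py.
correct = sum(r["predicted"] == r["answer"] for r in records)
print(f"Accuracy: {correct / len(records):.3f} ({correct}/{len(records)})")
```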
## Instruction Tuning on Continued Pretrained Model

We use the UltraChat dataset and the Llama 3.1 Instruct chat template:

```bash
python data/tokenize_instruct.py
```

This will save the instruction tuning data in `data/dataset/bins/ultrachat_train.bin` and `data/dataset/bins/ultrachat_test.bin`.
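To see what the Llama 3.1 Instruct chat template produces for a single UltraChat-style conversation, a minimal sketch is shown below; the tokenizer name is an assumption, and `data/tokenize_instruct.py` is authoritative for how conversations are actually formatted:

```python
from transformers import AutoTokenizer

# Assumption: the Llama 3.1 8B Instruct tokenizer supplies the chat template.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

messages = [
    {"role": "user", "content": "Summarize the plot of the article."},
    {"role": "assistant", "content": "The article follows ..."},
]

# Render the conversation without tokenizing, so you can inspect the
# special tokens the template inserts around each turn.
print(tokenizer.apply_chat_template(messages, tokenize=False))
```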
To perform instruction tuning on the continually pretrained model:
```bash
./scripts/train.sh \
    --lr 5e-06 \
    --rr 0.1 \
    --epochs 2 \
    --bs 128 \
    --wd 0.01 \
    --warmup 0.05 \
    --task_name instruct \
    --model_name ckpts/quality-lr5e-06-rr0.1-epochs2-bs16-wd0.01-warmup0.05-MetaLlama38B
```

The checkpoint will be saved in `ckpts/instruct-lr5e-06-rr0.1-epochs2-bs128-wd0.01-warmup0.05-qualitylr5e06rr0.1epochs2bs16wd0.01warmup0.05MetaLlama38B`.
To launch an interactive session with the instruction-tuned EntiGraph model:
```bash
python interactive.py
```

You can ask questions about QuALITY articles (e.g., Tell me about the article "Defining Decay Down".).
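`interactive.py` is the intended entry point. If you prefer to query the model programmatically instead, a minimal sketch along these lines should work, assuming the instruction-tuned checkpoint is in standard `transformers` format and its tokenizer bundles the chat template:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Path matches the instruction-tuned checkpoint name above.
CKPT = ("ckpts/instruct-lr5e-06-rr0.1-epochs2-bs128-wd0.01-warmup0.05-"
        "qualitylr5e06rr0.1epochs2bs16wd0.01warmup0.05MetaLlama38B")

tokenizer = AutoTokenizer.from_pretrained(CKPT)
model = AutoModelForCausalLM.from_pretrained(
    CKPT, torch_dtype=torch.bfloat16, device_map="auto"
)

messages = [{"role": "user", "content": 'Tell me about the article "Defining Decay Down".'}]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

# Generate a reply and strip the prompt tokens before decoding.
output = model.generate(inputs, max_new_tokens=256)
print(tokenizer.decode(output[0, inputs.shape[-1]:], skip_special_tokens=True))
```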
## Citation

If you use this code in your research, please cite our paper:
```bibtex
@misc{yang2024syntheticcontinuedpretraining,
      title={Synthetic continued pretraining},
      author={Zitong Yang and Neil Band and Shuangping Li and Emmanuel Candès and Tatsunori Hashimoto},
      year={2024},
      eprint={2409.07431},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2409.07431},
}
```