`projects/sd-ra-it/README.md`:
# Training LLMs on Self-generated Demonstrations

Scripts and configs for replicating the experiments from ["Post-training an LLM for RAG? Train on Self-Generated Demonstrations"](https://arxiv.org/abs/2502.10596).



You may cite our work as
```bibtex
@misc{finlayson2025posttraining,
  title={Post-training an LLM for RAG? Train on Self-Generated Demonstrations},
  author={Matthew Finlayson and Ilia Kulikov and Daniel M. Bikel and Barlas Oguz and Xilun Chen and Aasish Pappu},
  year={2025},
  eprint={2502.10596},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
}
```

## Generating self-demos

1. Obtain training data. We use the training data from the [RA-DIT paper](https://arxiv.org/abs/2310.01352), placed at `data/70b/train/tasks.jsonl` and `data/70b/train/oasst.jsonl` and subsampled with weights of 0.9 and 0.1, respectively.

2. Generate prompts. Use `scripts/prompt_optimization.py`, e.g.,
```sh
python scripts/prompt_optimization.py \
--dataset_filename "data/70b/tasks.jsonl" \
--model "Meta-Llama-3-70B-Instruct" \
--outfile "data/prompts/base.json" \
--logfile "70B_prompt_optimization.log" \
--eval_example_count 30 \
--train_example_count 30 \
--topk 5 \
--shuffle_window 400 \
--beam_size 12 \
--tensor-parallel-size=2 \
--chat \
--steps 5 \
--rag
```
3. Generate self-demos with `scripts/create_self_demo_train_set.sh`
```sh
bash scripts/create_self_demo_train_set.sh tasks Meta-Llama-3-70B-Instruct
bash scripts/create_self_demo_train_set.sh oasst Meta-Llama-3-70B-Instruct
```
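The self-demo sets generated above feed the preference-tuning stage below. As a rough sketch of how preference pairs could be assembled from scored self-generated responses (the `responses` and `reward` field names are hypothetical, not the actual output schema of `create_self_demo_train_set.sh`):

```python
def build_dpo_pairs(records):
    """Pair the best and worst self-generated responses per prompt.

    Each record is assumed to look like:
    {"prompt": str, "responses": [{"text": str, "reward": float}, ...]}
    """
    pairs = []
    for rec in records:
        ranked = sorted(rec["responses"], key=lambda r: r["reward"], reverse=True)
        # Skip prompts that carry no usable preference signal.
        if len(ranked) < 2 or ranked[0]["reward"] == ranked[-1]["reward"]:
            continue
        pairs.append({
            "prompt": rec["prompt"],
            "chosen": ranked[0]["text"],
            "rejected": ranked[-1]["text"],
        })
    return pairs
```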

## SFT and DPO training with `fairseq2`

To train a DPO model on self-demonstrations using fairseq2:

```sh
srun fairseq2 lm preference_finetune dpo_checkpoints/fairseq2/self_demo \
--config-file configs/dpo_70b.yml
```

Other configs correspond to SFT and smaller-scale (8B) training runs.

For library setup and examples, please refer to the fairseq2 documentation: https://facebookresearch.github.io/fairseq2/stable/

## Evaluation

1. Obtain eval data with retrievals. We use the eval sets from the [RA-DIT paper](https://arxiv.org/abs/2310.01352), which come with retrievals; place them in `data/ra-dit/`.
2. Convert eval files to the correct format.
```sh
python scripts/data/io_to_qas_format.py \
data/ra-dit/eli5/eli5-dev-kilt.jsonl \
data/ra-dit/eli5/dev.jsonl
```
3. Run the evaluation.
```sh
judge="Meta-Llama-3.1-405B-Instruct-FP8" # Set to judge model path
eval_set=nq # Set to one of `mmlu zsrequestion conllyagotrunc eli5 hotpotqa nq tqa trex fever wow`
strat=dpo_self_demo_70b # Set to training strategy name
hf_checkpoint= # Set to Huggingface model checkpoint path
pred_tpsize=8 # Tensor parallel size
model_size=70b
ndocs=4
samples=1
preds="data/${strat}/eval/preds/${eval_set}.jsonl"
reward_file="data/${strat}/eval/reward/${eval_set}.jsonl"
reward_file_gemma="data/${strat}/eval/reward/${eval_set}_gemma.jsonl"
response_labels="data/${strat}/eval/response_labels/${eval_set}.jsonl"
response_label_reasons="data/${strat}/eval/response_label_reasons/${eval_set}.jsonl"
relevance="data/relevance/${model_size}/${eval_set}.jsonl"
relevance_reasons="data/relevance_reasons/${model_size}/${eval_set}.jsonl"
resultsfile="results/${strat}/eval/metrics/${eval_set}.json"
datafile=data/ra-dit/${eval_set}/dev.jsonl

# Generate outputs
python scripts/generate.py \
--model=${hf_checkpoint} \
--outfile=$preds \
--samples=$samples \
--tensor-parallel-size=$pred_tpsize \
--ndocs=$ndocs \
--data $datafile

# Get reward model scores
python scripts/reward_model_gemma.py \
--outfile=$reward_file_gemma \
--responses=$preds \
--ndocs=$ndocs \
--data $datafile

# Identify whether context contains the answer
python scripts/relevance.py \
--datafile $datafile \
--reasoning_file $relevance_reasons \
--outfile $relevance \
--ndocs=$ndocs \
--tensor-parallel-size=$pred_tpsize \
--judge=$judge \
--logfile logs/relevance_${eval_set}.log

# Evaluate (correct/incorrect/refuse) model outputs.
python scripts/eval.py \
--preds $preds \
--datafile $datafile \
--outfile $response_labels \
--reasoning_file $response_label_reasons \
--tensor-parallel-size=$pred_tpsize \
--judge=$judge \
--logfile logs/response_labels_${eval_set}_${strat}.log
```
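Once the judge has labeled each prediction, per-set metrics can be aggregated along these lines (the `correct`/`incorrect`/`refuse` label strings are assumptions about the output of `scripts/eval.py`, not its documented schema):

```python
from collections import Counter

def summarize(labels):
    # labels: one judge verdict per example, e.g. "correct", "incorrect", "refuse".
    counts = Counter(labels)
    n = sum(counts.values())
    return {
        "accuracy": counts["correct"] / n,
        "refusal_rate": counts["refuse"] / n,
        # Accuracy restricted to examples the model actually attempted.
        "attempted_accuracy": counts["correct"] / max(n - counts["refuse"], 1),
    }
```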
`projects/sd-ra-it/configs/dpo_70b.yaml`:
```yaml
model:
  _set_:
    name: llama3_70b_instruct
dataset:
  _set_:
    name: sdrait
    path: data/70b/train/self_demo/dpo
    max_seq_len: 4096
    batch_size: 1
gang:
  _set_:
    tensor_parallel_size: 8
trainer:
  fsdp:
    _set_:
      version: v1
      granularity: layer
      hsdp: false
      reshard_after_forward: true
      fp32_reduce: true
  _set_:
    dtype: bfloat16
    data_parallelism: fsdp
    mixed_precision: static
    gradient_accumulation: 4
    activation_checkpointing: true
    max_gradient_norm: null
    fp16_loss_scale:
    - 128.0
    - 0.0001
    torch_compile: false
    profile: null
    gradient_check: false
    anomaly_detection: false
criterion:
  config:
    reference_model:
      _set_:
        name: llama3_70b_instruct
    _set_:
      reference_dtype: bfloat16
      beta: 0.1
      nll_scale: 0.0
      length_normalization: false
  _set_:
    name: dpo
optimizer:
  config:
    _set_:
      lr: 5.5e-06
lr_scheduler:
  config:
    _set_:
      cycle_len: null
      num_warmup_steps: 0
      cycle_mul: 1.0
      lr_mul: 1.0
      start_lr: 0.0
      final_lr: 1.1e-06
      final_lr_scale: null
  _set_:
    name: cosine_annealing
regime:
  _set_:
    num_steps: 800
    num_data_epochs: 5
    checkpoint_every_n_steps: 1000
    checkpoint_after_n_data_epochs: 1
    checkpoint_every_n_data_epochs: null
    keep_last_n_checkpoints: 1
    publish_metrics_every_n_steps: 5
```
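For reference, the `beta: 0.1` and `nll_scale: 0.0` settings above enter the DPO objective roughly as follows; this is a generic single-example sketch of the DPO loss, not fairseq2's internal implementation:

```python
import math

def dpo_loss(chosen_logp, rejected_logp, ref_chosen_logp, ref_rejected_logp,
             beta=0.1, nll_scale=0.0):
    # Implicit rewards: beta-scaled log-prob ratios against the frozen reference.
    margin = beta * ((chosen_logp - ref_chosen_logp)
                     - (rejected_logp - ref_rejected_logp))
    # -log(sigmoid(margin)), computed stably as log(1 + exp(-margin)).
    loss = math.log1p(math.exp(-margin))
    # Optional NLL regularizer on the chosen response (disabled when nll_scale is 0).
    return loss + nll_scale * (-chosen_logp)
```

With a zero margin the loss is log 2; a larger gap between the chosen and rejected log-prob ratios drives it toward zero.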
`projects/sd-ra-it/configs/dpo_8b.yaml`:
```yaml
model:
  _set_:
    name: llama3_8b_instruct
dataset:
  _set_:
    name: sdrait
    path: data/70b/train/self_demo/dpo
    max_seq_len: 4096
    batch_size: 1
trainer:
  fsdp:
    _set_:
      version: v1
      granularity: layer
      hsdp: false
      reshard_after_forward: true
      fp32_reduce: true
  _set_:
    dtype: bfloat16
    data_parallelism: fsdp
    mixed_precision: static
    gradient_accumulation: 4
    activation_checkpointing: true
    max_gradient_norm: null
    fp16_loss_scale:
    - 128.0
    - 0.0001
    torch_compile: false
    profile: null
    gradient_check: false
    anomaly_detection: false
criterion:
  config:
    reference_model:
      _set_:
        name: llama3_8b_instruct
    _set_:
      reference_dtype: bfloat16
      beta: 0.1
      nll_scale: 0.0
      length_normalization: false
  _set_:
    name: dpo
optimizer:
  config:
    _set_:
      lr: 5.5e-06
lr_scheduler:
  config:
    _set_:
      cycle_len: null
      num_warmup_steps: 0
      cycle_mul: 1.0
      lr_mul: 1.0
      start_lr: 0.0
      final_lr: 1.1e-06
      final_lr_scale: null
  _set_:
    name: cosine_annealing
regime:
  _set_:
    num_steps: 800
    num_data_epochs: 5
    checkpoint_every_n_steps: 1000
    checkpoint_after_n_data_epochs: 1
    checkpoint_every_n_data_epochs: null
    keep_last_n_checkpoints: 1
    publish_metrics_every_n_steps: 5
```
`projects/sd-ra-it/configs/sft_70b.yaml`:
```yaml
model:
  _set_:
    name: llama3_70b_instruct
dataset:
  _set_:
    name: sdrait
    path: ra-dit/train
gang:
  _set_:
    tensor_parallel_size: 8
optimizer:
  config:
    _set_:
      lr: 5.5e-06
```
`projects/sd-ra-it/configs/sft_8b.yaml`:
```yaml
model:
  _set_:
    name: llama3_8b_instruct
dataset:
  _set_:
    name: sdrait
    path: ra-dit/train
optimizer:
  config:
    _set_:
      lr: 5.5e-06
```
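These configs are preset overrides: `_set_` assigns leaf values, while every other key descends into a nested section. A minimal sketch of that merge semantics (an assumption about how fairseq2 applies overrides, not its actual code):

```python
def apply_overrides(base, overrides):
    # Keys named "_set_" assign leaf values directly onto the current
    # section; all other keys recurse into the corresponding subsection.
    out = dict(base)
    for key, value in overrides.items():
        if key == "_set_":
            out.update(value)
        else:
            out[key] = apply_overrides(out.get(key, {}), value)
    return out
```

For example, the SFT override above would replace only `optimizer.config.lr` in a preset, leaving sibling optimizer settings untouched.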
`projects/sd-ra-it/scripts/create_self_demo_train_set.sh`:
```sh
#!/usr/bin/env sh

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

set -e
eval "$(conda shell.bash hook)"

conda activate pytorch
train_set=$1
model_name=$2
if [ "$train_set" = "oasst" ]
then
    num=20_000
elif [ "$train_set" = "tasks" ]
then
    num=200_000
fi
if [ "$model_name" = "Meta-Llama-3-70B-Instruct" ]
then
    tensor_parallel_size=4
else
    tensor_parallel_size=1
fi
output_file="data/70b/train/${train_set}.jsonl"
mkdir -p "$(dirname "$output_file")"
python scripts/get_demos.py \
    --filename "ra-dit/multisource/${train_set}.jsonl" \
    --output_file "$output_file" \
    --n "$num" \
    --prompts-per-strat 3 \
    --tensor_parallel_size "$tensor_parallel_size" \
    --model_name "$model_name" \
    --continued \
    --logfile "logs/get_demos_${train_set}_70b.log"
```
`projects/sd-ra-it/scripts/data/io_to_qas_format.py`:
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import json
import os
import sys

from tqdm import tqdm

infilename = sys.argv[1]
outfilename = sys.argv[2]
with open(infilename) as infile, open(outfilename, "w") as outfile:
for line in tqdm(map(json.loads, infile)):
question = line["input"]
answers = [
output.get("answer")
for output in line["output"]
if output.get("answer") is not None
]
output_line = line | dict(question=question, answers=answers)
print(json.dumps(output_line), file=outfile)
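At the record level, the conversion above adds flat `question` and `answers` fields alongside the original KILT-style ones (the sample record here is illustrative):

```python
import json

# A KILT-style eval record: "input" holds the question, "output" holds
# candidate annotations, only some of which carry an "answer".
record = {"input": "Who wrote Hamlet?",
          "output": [{"answer": "Shakespeare"}, {"provenance": []}]}
qas = record | {
    "question": record["input"],
    "answers": [o.get("answer") for o in record["output"]
                if o.get("answer") is not None],
}
print(json.dumps(qas))
```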