Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions sota-implementations/expert-iteration/ei_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

from torchrl._utils import logger as torchrl_logger
from torchrl.envs.llm import RetrieveLogProb
from torchrl.envs.llm.datasets.countdown import CountdownEnv
from torchrl.envs.llm.datasets.ifeval import IFEvalEnv
from torchrl.envs.llm.datasets.math import MATHEnv
from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
from torchrl.weight_update.llm import VLLMWeightSyncScheme
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
Expand Down Expand Up @@ -63,22 +65,24 @@ def make_env(cfg: DictConfig, devices: list[int] | None = None):
ref_model = get_ref_model(ref_cfg, train_tokenizer, devices=devices)

# Setup environment
common_kwargs = {
"repeats": cfg.env.repeats,
"tokenizer": train_tokenizer,
"num_envs": cfg.env.num_envs,
"device": torch.device("cpu"),
}
if cfg.env.dataset == "gsm8k":
from torchrl.envs.llm import GSM8KEnv

env = GSM8KEnv(
repeats=cfg.env.repeats,
tokenizer=train_tokenizer,
num_envs=cfg.env.num_envs,
device=torch.device("cpu"),
)
else: # ifeval
env = IFEvalEnv(
repeats=cfg.env.repeats,
tokenizer=train_tokenizer,
num_envs=cfg.env.num_envs,
device=torch.device("cpu"),
)
env = GSM8KEnv(**common_kwargs)
elif cfg.env.dataset == "ifeval":
env = IFEvalEnv(**common_kwargs)
elif cfg.env.dataset == "math":
env = MATHEnv(**common_kwargs)
elif cfg.env.dataset == "countdown":
env = CountdownEnv(**common_kwargs)
else:
raise NotImplementedError(f"Dataset {cfg.env.dataset} not implemented")

# Pass device directly to RetrieveLogProb - Since, for Ray, the local device is always 0
# we can just use 0 here.
Expand Down
111 changes: 111 additions & 0 deletions sota-implementations/grpo/config/grpo_countdown.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# @package _global_
# GRPO training configuration for the Countdown numbers-game dataset.
# Structure mirrors the sibling grpo_gsm8k.yaml config.
defaults:
  - mode: ${mode:async}
  - _self_
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled

# Environment configuration
env:
  dataset: countdown  # choices: [gsm8k, ifeval, math, countdown]
  # Number of environments to run in parallel (batch size passed to vLLM).
  num_envs: 32
  # Number of times each prompt is repeated (group size for GRPO advantage).
  repeats: 16
  reasoning: false
  max_steps: 2

# Base model configuration
model:
  name: Qwen/Qwen2.5-3B
  compile: false

# Training loop configuration
train:
  exp_name: "grpo-countdown"
  mixed_precision: true
  total_dialog_turns: 100_000
  packing: false
  dialog_turns_per_batch: 32
  gradient_accumulation_steps: 8
  checkpoint_frequency: 100
  optim_batch_size: 32
  kl_coef_in_loss: true
  use_kl_to_ref: false
  kl_to_ref_coeff: 0.0
  kl_to_inference_coeff: 1e-2
  entropy_coeff: 1e-4
  logging_frequency: 10
  empty_replay_buffer: true

# Policy model being optimized
train_model:
  gradient_checkpointing: true
  num_devices: 1
  lora:
    enabled: true
    r: 8
    alpha: 16
    dropout: 0.1
  quantization:
    enabled: false
  attn_implementation: sdpa
  torch_dtype: bfloat16

# vLLM inference model used for rollouts
inference_model:
  num_devices: 1
  quantization:
    enabled: false
  attn_implementation: sdpa
  torch_dtype: bfloat16
  gpu_memory_utilization: 0.9
  temperature: 1.0
  top_p: 0.95
  max_tokens: 512
  include_stop_str_in_output: true
  enforce_eager: false

# Frozen reference model for KL regularization
ref_model:
  gradient_checkpointing: false
  num_devices: 1
  lora:
    enabled: true
    r: 8
    alpha: 16
    dropout: 0.1
  quantization:
    enabled: false
  attn_implementation: sdpa
  torch_dtype: bfloat16

optimizer:
  name: AdamW
  lr: 1e-5
  clip_grad_norm: 1.0
  weight_decay: 0.0

# Ray cluster configuration
ray:
  init_config:
    num_cpus: 96
    num_gpus: 8
    runtime_env:
      working_dir: "."
    _temp_dir: "/tmp/ray_grpo"
    _system_config:
      object_spilling_threshold: 0.8
      # 10 GiB as a literal integer: YAML does not evaluate arithmetic
      # expressions, so "10 * 1024 * 1024 * 1024" would be passed as a string.
      max_direct_memory_size: 10737418240
      object_store_full_delay_ms: 100
      object_store_full_max_retries: 3
  collector_config:
    num_cpus: 4
  train_handler_config:
    num_cpus: 4
  replay_buffer_config:
    num_cpus: 4
    num_gpus: 0.0

logging:
  experiment_name: null
  checkpoint_dir: "checkpoints"
  checkpoint_frequency: 10

hydra:
  run:
    dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
  sweep:
    dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
    subdir: ${hydra.job.num}
2 changes: 1 addition & 1 deletion sota-implementations/grpo/config/grpo_gsm8k.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ defaults:

# Environment configuration
env:
dataset: gsm8k # choices: [gsm8k, ifeval]
dataset: gsm8k # choices: [gsm8k, ifeval, math, countdown]
# Number of environments to run in parallel. This determines the batch size passed to vLLM.
# More envs do not consume more GPU memory but there will be a sync on the call to vLLM.
num_envs: 32
Expand Down
111 changes: 111 additions & 0 deletions sota-implementations/grpo/config/grpo_math.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# @package _global_
# GRPO training configuration for the MATH dataset.
# Structure mirrors the sibling grpo_gsm8k.yaml config.
defaults:
  - mode: ${mode:async}
  - _self_
  - override hydra/hydra_logging: disabled
  - override hydra/job_logging: disabled

# Environment configuration
env:
  dataset: math  # choices: [gsm8k, ifeval, math, countdown]
  # Number of environments to run in parallel (batch size passed to vLLM).
  num_envs: 32
  # Number of times each prompt is repeated (group size for GRPO advantage).
  repeats: 16
  reasoning: false
  max_steps: 2

# Base model configuration
model:
  name: Qwen/Qwen2.5-3B
  compile: false

# Training loop configuration
train:
  exp_name: "grpo-math"
  mixed_precision: true
  total_dialog_turns: 100_000
  packing: false
  dialog_turns_per_batch: 32
  gradient_accumulation_steps: 8
  checkpoint_frequency: 100
  optim_batch_size: 32
  kl_coef_in_loss: true
  use_kl_to_ref: false
  kl_to_ref_coeff: 0.0
  kl_to_inference_coeff: 1e-2
  entropy_coeff: 1e-4
  logging_frequency: 10
  empty_replay_buffer: true

# Policy model being optimized
train_model:
  gradient_checkpointing: true
  num_devices: 1
  lora:
    enabled: true
    r: 8
    alpha: 16
    dropout: 0.1
  quantization:
    enabled: false
  attn_implementation: sdpa
  torch_dtype: bfloat16

# vLLM inference model used for rollouts
inference_model:
  num_devices: 1
  quantization:
    enabled: false
  attn_implementation: sdpa
  torch_dtype: bfloat16
  gpu_memory_utilization: 0.9
  temperature: 1.0
  top_p: 0.95
  # Longer generations than countdown (1024 vs 512): MATH solutions are longer.
  max_tokens: 1024
  include_stop_str_in_output: true
  enforce_eager: false

# Frozen reference model for KL regularization
ref_model:
  gradient_checkpointing: false
  num_devices: 1
  lora:
    enabled: true
    r: 8
    alpha: 16
    dropout: 0.1
  quantization:
    enabled: false
  attn_implementation: sdpa
  torch_dtype: bfloat16

optimizer:
  name: AdamW
  lr: 1e-5
  clip_grad_norm: 1.0
  weight_decay: 0.0

# Ray cluster configuration
ray:
  init_config:
    num_cpus: 96
    num_gpus: 8
    runtime_env:
      working_dir: "."
    _temp_dir: "/tmp/ray_grpo"
    _system_config:
      object_spilling_threshold: 0.8
      # 10 GiB as a literal integer: YAML does not evaluate arithmetic
      # expressions, so "10 * 1024 * 1024 * 1024" would be passed as a string.
      max_direct_memory_size: 10737418240
      object_store_full_delay_ms: 100
      object_store_full_max_retries: 3
  collector_config:
    num_cpus: 4
  train_handler_config:
    num_cpus: 4
  replay_buffer_config:
    num_cpus: 4
    num_gpus: 0.0

logging:
  experiment_name: null
  checkpoint_dir: "checkpoints"
  checkpoint_frequency: 10

hydra:
  run:
    dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
  sweep:
    dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
    subdir: ${hydra.job.num}
37 changes: 19 additions & 18 deletions sota-implementations/grpo/grpo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

from torchrl._utils import logger as torchrl_logger, timeit
from torchrl.envs.llm import AddThinkingPrompt, GSM8KEnv, KLRewardTransform, RetrieveKL
from torchrl.envs.llm.datasets.countdown import CountdownEnv
from torchrl.envs.llm.datasets.ifeval import IFEvalEnv
from torchrl.envs.llm.datasets.math import MATHEnv
from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
from torchrl.weight_update.llm import VLLMWeightSyncScheme
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
Expand Down Expand Up @@ -648,28 +650,27 @@ def make_env(cfg: DictConfig, single_env: bool = False):

# Setup environment
max_steps = cfg.env.max_steps if cfg.env.reasoning else 1
num_envs = cfg.env.num_envs if not single_env else 1
common_kwargs = {
"repeats": cfg.env.repeats,
"tokenizer": train_tokenizer,
"num_envs": num_envs,
"max_steps": max_steps,
"device": torch.device("cpu"),
}

if cfg.env.dataset == "gsm8k":
# Reward scale is 0.0 to 1.0
reward_threshold = 0.1
env = GSM8KEnv(
repeats=cfg.env.repeats,
tokenizer=train_tokenizer,
num_envs=cfg.env.num_envs if not single_env else 1,
max_steps=max_steps,
device=torch.device("cpu"),
ray_backend=True,
)
env = GSM8KEnv(**common_kwargs, ray_backend=True)
elif cfg.env.dataset == "ifeval":
# Reward scale is 0.0 to ~1.15
reward_threshold = 0.5
env = IFEvalEnv(
repeats=cfg.env.repeats,
tokenizer=train_tokenizer,
num_envs=cfg.env.num_envs if not single_env else 1,
max_steps=max_steps,
device=torch.device("cpu"),
ray_backend=True,
)
env = IFEvalEnv(**common_kwargs, ray_backend=True)
elif cfg.env.dataset == "math":
reward_threshold = 0.1
env = MATHEnv(**common_kwargs, ray_backend=True)
elif cfg.env.dataset == "countdown":
reward_threshold = 0.1
env = CountdownEnv(**common_kwargs)
else:
raise NotImplementedError(f"Dataset {cfg.env.dataset} not implemented")

Expand Down
Loading