Skip to content

Commit dacac7e

Browse files
authored
feat: Support lora in dtensor grpo workflow by merging weight (#1797)
Signed-off-by: ruit <[email protected]>
1 parent 1f55e25 commit dacac7e

File tree

10 files changed

+362
-17
lines changed

10 files changed

+362
-17
lines changed

examples/configs/grpo_math_1B.yaml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,19 @@ policy:
9090
tensor_parallel_size: 1
9191
context_parallel_size: 1
9292
custom_parallel_plan: null
93+
94+
# LoRA (Low-Rank Adaptation) Configuration
95+
lora_cfg:
96+
enabled: False # Set to True to enable LoRA fine-tuning
97+
target_modules: [] # List of module names to apply LoRA (empty list with match_all_linear=true applies to all linear layers)
98+
exclude_modules: [] # List of module names to exclude from LoRA
99+
match_all_linear: true # If True, applies LoRA to all linear layers (overrides target_modules)
100+
dim: 8 # LoRA rank (r): lower rank = fewer parameters but less capacity. Typical values: 4, 8, 16, 32, 64
101+
alpha: 32 # LoRA scaling factor: effective learning rate multiplier = alpha/dim. Typical values: 16, 32, 64
102+
dropout: 0.0 # Dropout probability applied to LoRA layers (0.0 = no dropout)
103+
dropout_position: "post" # Where to apply dropout: "pre" (before LoRA) or "post" (after LoRA)
104+
lora_A_init: "xavier" # Initialization method for LoRA A matrix: "xavier" or "uniform"
105+
use_triton: true # Use Triton-optimized kernels for LoRA (faster but requires flash-attn). Disable when tensor_parallel_size > 1
93106

94107
megatron_cfg:
95108
enabled: false
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# GRPO math recipe: Qwen3-8B-Base on 1 node x 8 GPUs, FSDP2 (DTensor) with LoRA.
# Everything not overridden here is inherited from the base config below.
defaults: ../../grpo_math_1B.yaml
grpo:
  val_at_start: true  # run validation before the first training step
checkpointing:
  checkpoint_dir: results/grpo-qwen3-8B-base-1n8g-fsdp2-lora
policy:
  model_name: Qwen/Qwen3-8B-Base
  max_total_sequence_length: 2048
  dtensor_cfg:
    activation_checkpointing: true
    # LoRA adapter overrides; remaining lora_cfg fields come from the base config.
    lora_cfg:
      enabled: True
      dim: 128    # LoRA rank (r)
      alpha: 128  # scaling factor; effective scale is alpha/dim — presumably 1.0 here
  sequence_packing:
    enabled: false
logger:
  log_dir: logs/grpo-qwen3-8B-base-1n8g-fsdp2-lora
  wandb_enabled: true
  tensorboard_enabled: true
  wandb:
    project: nemo-rl
    name: grpo-qwen3-8B-base-1n8g-fsdp2-lora
cluster:
  gpus_per_node: 8

nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import ray
2323
import torch
24+
from nemo_automodel.components._peft.lora import LinearLoRA
2425
from nemo_automodel.components.distributed.cp_utils import (
2526
create_context_parallel_ctx,
2627
)
@@ -85,19 +86,66 @@ def dtensor_params_generator(
8586
Args:
8687
model: The model whose parameters to generate.
8788
target_dtype: The dtype to convert tensors to.
89+
peft_config: Optional LoRA config for filtering which layers to merge.
8890
8991
Yields:
9092
Tuples of (fully_qualified_name, tensor) where tensors are converted to target dtype and made contiguous.
9193
"""
94+
module_map = dict(model.named_modules())
9295
for name, tensor in model.state_dict().items():
96+
if name.endswith(".lora_A.weight") or name.endswith(".lora_B.weight"):
97+
continue
9398
full_tensor = tensor.full_tensor() if isinstance(tensor, DTensor) else tensor
94-
adapted_fqn_tensors = _maybe_adapt_tensor_to_hf(model, name, full_tensor)
99+
merged_tensor = _maybe_merge_lora_weight(module_map, name, full_tensor)
100+
101+
adapted_fqn_tensors = _maybe_adapt_tensor_to_hf(model, name, merged_tensor)
95102
for adapted_fqn, adapted_tensor in adapted_fqn_tensors:
96103
# Convert to target dtype
97104
yield (
98105
adapted_fqn,
99106
adapted_tensor.to(target_dtype, non_blocking=True).contiguous(),
100107
)
108+
del adapted_tensor
109+
del adapted_fqn_tensors
110+
del merged_tensor
111+
del full_tensor
112+
113+
114+
@torch.no_grad()
115+
def _maybe_merge_lora_weight(
116+
module_map: dict[str, nn.Module],
117+
fqn: str,
118+
tensor: torch.Tensor,
119+
) -> torch.Tensor:
120+
if not fqn.endswith(".weight"):
121+
return tensor
122+
module_name = fqn[: -len(".weight")]
123+
module = module_map.get(module_name)
124+
if not isinstance(module, LinearLoRA):
125+
return tensor
126+
if not (hasattr(module, "lora_A") and hasattr(module, "lora_B")):
127+
return tensor
128+
129+
lora_a = (
130+
module.lora_A.weight.full_tensor()
131+
if isinstance(module.lora_A.weight, DTensor)
132+
else module.lora_A.weight
133+
)
134+
lora_b = (
135+
module.lora_B.weight.full_tensor()
136+
if isinstance(module.lora_B.weight, DTensor)
137+
else module.lora_B.weight
138+
)
139+
lora_a = lora_a.to(device=tensor.device, dtype=tensor.dtype)
140+
lora_b = lora_b.to(device=tensor.device, dtype=tensor.dtype)
141+
scale = getattr(module, "scale", None)
142+
143+
if scale is None and hasattr(module, "alpha") and hasattr(module, "dim"):
144+
scale = module.alpha / module.dim
145+
if scale is None:
146+
scale = 1.0
147+
148+
return tensor + torch.matmul(lora_b, lora_a) * scale
101149

102150

103151
def _maybe_adapt_tensor_to_hf(
@@ -1208,6 +1256,8 @@ def prepare_refit_info(self) -> Optional[dict[str, Any]]:
12081256
"""Prepare state dict metadata for weight refitting and IPC streaming."""
12091257
state_dict_info = {}
12101258
for name, tensor in self.model.state_dict().items():
1259+
if name.endswith(".lora_A.weight") or name.endswith(".lora_B.weight"):
1260+
continue
12111261
full_tensor = (
12121262
tensor.full_tensor() if isinstance(tensor, DTensor) else tensor
12131263
)

tests/functional/L1_Functional_Tests_GPU.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ time uv run --no-sync bash ./tests/functional/sft.sh
2727
time uv run --no-sync bash ./tests/functional/sft_resume_diamond.sh
2828
time uv run --no-sync bash ./tests/functional/grpo.sh
2929
time uv run --no-sync bash ./tests/functional/grpo_async.sh
30+
time uv run --no-sync bash ./tests/functional/grpo_automodel_lora_async.sh
31+
time uv run --no-sync bash ./tests/functional/grpo_automodel_lora_non_colocated.sh
32+
time uv run --no-sync bash ./tests/functional/grpo_automodel_lora.sh
3033
time uv run --no-sync bash ./tests/functional/grpo_megatron.sh
3134
time uv run --no-sync bash ./tests/functional/grpo_megatron_generation.sh
3235
time uv run --no-sync bash ./tests/functional/grpo_multiturn.sh
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#!/bin/bash

# Functional smoke test: GRPO + LoRA on the DTensor/automodel path.
# Runs 3 training steps on 2 GPUs, dumps TensorBoard logs to JSON, and
# asserts a minimum train reward.

# Clean up checkpoint directory on exit.
# NOTE(review): this path is inherited from the SFT LoRA test; checkpointing is
# disabled below, so this trap is defensive only — confirm it is still wanted.
trap "rm -rf /tmp/lora_sft_checkpoints" EXIT

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
PROJECT_ROOT=$(realpath "$SCRIPT_DIR/../..")
# Mark the current repo as safe, since wandb fetches metadata about the repo
git config --global --add safe.directory "$PROJECT_ROOT"

set -eou pipefail

EXP_NAME=$(basename "$0" .sh)
EXP_DIR=$SCRIPT_DIR/$EXP_NAME
LOG_DIR=$EXP_DIR/logs
JSON_METRICS=$EXP_DIR/metrics.json
RUN_LOG=$EXP_DIR/run.log
export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}

rm -rf "$EXP_DIR" "$LOG_DIR"
mkdir -p "$EXP_DIR" "$LOG_DIR"

cd "$PROJECT_ROOT"
uv run coverage run -a --data-file="$PROJECT_ROOT/tests/.coverage" --source="$PROJECT_ROOT/nemo_rl" \
    "$PROJECT_ROOT/examples/run_grpo_math.py" \
    grpo.max_num_steps=3 \
    grpo.num_prompts_per_step=8 \
    grpo.num_generations_per_prompt=4 \
    data.shuffle=false \
    policy.dtensor_cfg.lora_cfg.enabled=True \
    policy.dtensor_cfg.lora_cfg.dim=32 \
    policy.train_global_batch_size=32 \
    policy.train_micro_batch_size=1 \
    cluster.gpus_per_node=2 \
    logger.tensorboard_enabled=true \
    logger.log_dir="$LOG_DIR" \
    logger.wandb_enabled=false \
    logger.monitor_gpus=true \
    checkpointing.enabled=false \
    "$@" \
    2>&1 | tee "$RUN_LOG"

# Convert TensorBoard logs to JSON, then assert the reward threshold.
uv run tests/json_dump_tb_logs.py "$LOG_DIR" --output_path "$JSON_METRICS"

uv run tests/check_metrics.py "$JSON_METRICS" \
    'max(data["train/reward"]) > 0.03'
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#!/bin/bash

# Functional smoke test: GRPO + LoRA (DTensor/automodel path) with async GRPO,
# vLLM async engine, and non-colocated generation (1 training GPU + 1
# generation GPU). Runs 3 steps, then asserts a minimum train reward.

# Clean up checkpoint directory on exit.
# NOTE(review): this path is inherited from the SFT LoRA test; checkpointing is
# disabled below, so this trap is defensive only — confirm it is still wanted.
trap "rm -rf /tmp/lora_sft_checkpoints" EXIT

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
PROJECT_ROOT=$(realpath "$SCRIPT_DIR/../..")
# Mark the current repo as safe, since wandb fetches metadata about the repo
git config --global --add safe.directory "$PROJECT_ROOT"

set -eou pipefail

EXP_NAME=$(basename "$0" .sh)
EXP_DIR=$SCRIPT_DIR/$EXP_NAME
LOG_DIR=$EXP_DIR/logs
JSON_METRICS=$EXP_DIR/metrics.json
RUN_LOG=$EXP_DIR/run.log
export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}

rm -rf "$EXP_DIR" "$LOG_DIR"
mkdir -p "$EXP_DIR" "$LOG_DIR"

cd "$PROJECT_ROOT"
# NRL_FORCE_REBUILD_VENVS: the non-colocated generation workers run in their
# own venvs; force a rebuild so they pick up the current source tree.
NRL_FORCE_REBUILD_VENVS=true uv run coverage run -a --data-file="$PROJECT_ROOT/tests/.coverage" --source="$PROJECT_ROOT/nemo_rl" \
    "$PROJECT_ROOT/examples/run_grpo_math.py" \
    grpo.max_num_steps=3 \
    grpo.num_prompts_per_step=8 \
    grpo.num_generations_per_prompt=4 \
    data.shuffle=false \
    policy.dtensor_cfg.lora_cfg.enabled=True \
    policy.dtensor_cfg.lora_cfg.dim=32 \
    policy.train_global_batch_size=32 \
    policy.train_micro_batch_size=1 \
    policy.generation.colocated.enabled=false \
    policy.generation.colocated.resources.gpus_per_node=1 \
    policy.generation.colocated.resources.num_nodes=1 \
    policy.generation.vllm_cfg.async_engine=true \
    grpo.async_grpo.enabled=true \
    loss_fn.use_importance_sampling_correction=true \
    cluster.gpus_per_node=2 \
    logger.tensorboard_enabled=true \
    logger.log_dir="$LOG_DIR" \
    logger.wandb_enabled=false \
    logger.monitor_gpus=true \
    checkpointing.enabled=false \
    "$@" \
    2>&1 | tee "$RUN_LOG"

# Convert TensorBoard logs to JSON, then assert the reward threshold.
uv run tests/json_dump_tb_logs.py "$LOG_DIR" --output_path "$JSON_METRICS"

uv run tests/check_metrics.py "$JSON_METRICS" \
    'max(data["train/reward"]) > 0.03'
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#!/bin/bash

# Functional smoke test: GRPO + LoRA (DTensor/automodel path) with
# non-colocated generation (1 training GPU + 1 generation GPU).
# Runs 3 steps, then asserts a minimum train reward.

# Clean up checkpoint directory on exit.
# NOTE(review): this path is inherited from the SFT LoRA test; checkpointing is
# disabled below, so this trap is defensive only — confirm it is still wanted.
trap "rm -rf /tmp/lora_sft_checkpoints" EXIT

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
PROJECT_ROOT=$(realpath "$SCRIPT_DIR/../..")
# Mark the current repo as safe, since wandb fetches metadata about the repo
git config --global --add safe.directory "$PROJECT_ROOT"

set -eou pipefail

EXP_NAME=$(basename "$0" .sh)
EXP_DIR=$SCRIPT_DIR/$EXP_NAME
LOG_DIR=$EXP_DIR/logs
JSON_METRICS=$EXP_DIR/metrics.json
RUN_LOG=$EXP_DIR/run.log
export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}

rm -rf "$EXP_DIR" "$LOG_DIR"
mkdir -p "$EXP_DIR" "$LOG_DIR"

cd "$PROJECT_ROOT"
# NRL_FORCE_REBUILD_VENVS: the non-colocated generation workers run in their
# own venvs; force a rebuild so they pick up the current source tree.
NRL_FORCE_REBUILD_VENVS=true uv run coverage run -a --data-file="$PROJECT_ROOT/tests/.coverage" --source="$PROJECT_ROOT/nemo_rl" \
    "$PROJECT_ROOT/examples/run_grpo_math.py" \
    grpo.max_num_steps=3 \
    grpo.num_prompts_per_step=8 \
    grpo.num_generations_per_prompt=4 \
    data.shuffle=false \
    policy.dtensor_cfg.lora_cfg.enabled=True \
    policy.dtensor_cfg.lora_cfg.dim=32 \
    policy.train_global_batch_size=32 \
    policy.train_micro_batch_size=1 \
    policy.generation.colocated.enabled=false \
    policy.generation.colocated.resources.gpus_per_node=1 \
    policy.generation.colocated.resources.num_nodes=1 \
    cluster.gpus_per_node=2 \
    logger.tensorboard_enabled=true \
    logger.log_dir="$LOG_DIR" \
    logger.wandb_enabled=false \
    logger.monitor_gpus=true \
    checkpointing.enabled=false \
    "$@" \
    2>&1 | tee "$RUN_LOG"

# Convert TensorBoard logs to JSON, then assert the reward threshold.
uv run tests/json_dump_tb_logs.py "$LOG_DIR" --output_path "$JSON_METRICS"

uv run tests/check_metrics.py "$JSON_METRICS" \
    'max(data["train/reward"]) > 0.03'
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#!/bin/bash

# Nightly test-suite driver: GRPO + LoRA, Qwen3-8B-Base, 1 node x 8 GPUs
# (FSDP2/DTensor). Runs the full 20-step recipe, then checks convergence
# and timing metrics from the TensorBoard logs.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
# common.env provides PROJECT_ROOT, CONFIG_PATH, LOG_DIR, CKPT_DIR, EXP_NAME,
# RUN_LOG, JSON_METRICS and the exit_if_max_steps_reached helper.
source "$SCRIPT_DIR/common.env"

# ===== BEGIN CONFIG =====
NUM_NODES=1
STEPS_PER_RUN=20
MAX_STEPS=20
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))  # Round up
NUM_MINUTES=30
# ===== END CONFIG =====

exit_if_max_steps_reached

# Run the experiment
cd "$PROJECT_ROOT"
uv run examples/run_grpo_math.py \
    --config "$CONFIG_PATH" \
    grpo.max_num_steps="$MAX_STEPS" \
    logger.log_dir="$LOG_DIR" \
    logger.wandb_enabled=True \
    logger.wandb.project=nemo-rl \
    logger.wandb.name="$EXP_NAME" \
    logger.monitor_gpus=True \
    logger.tensorboard_enabled=True \
    checkpointing.enabled=True \
    checkpointing.checkpoint_dir="$CKPT_DIR" \
    "$@" \
    2>&1 | tee "$RUN_LOG"

# Convert tensorboard logs to json
uv run tests/json_dump_tb_logs.py "$LOG_DIR" --output_path "$JSON_METRICS"

# Only run metrics if the target step is reached
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' "$JSON_METRICS") -ge $MAX_STEPS ]]; then
    uv run tests/check_metrics.py "$JSON_METRICS" \
        'mean(data["train/gen_kl_error"], 20) < 0.002' \
        'data["train/gen_kl_error"]["20"] < 0.002' \
        'max(data["train/reward"]) > 0.35' \
        'mean(data["timing/train/total_step_time"], 2) < 80'

    # Clean up checkpoint directory after successful run to save space.
    rm -rf "$CKPT_DIR"
fi

tests/test_suites/nightly.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ tests/test_suites/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.sh
6666
tests/test_suites/llm/grpo-nano-v2-12b-1n8g-megatron.sh
6767
tests/test_suites/llm/grpo-nano-v2-12b-2n8g-fsdp2tp1.sh
6868

69+
# lora
70+
tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh
71+
6972
#######
7073
# SFT #
7174
#######

0 commit comments

Comments
 (0)