Commit 32f5bef

samodi-nv, RayenTian, and joyang-nv authored
feat: LoRA SFT support for DTensorV2 path (#1556)
Signed-off-by: Sahil Modi <samodi@nvidia.com>
Signed-off-by: ruit <ruit@nvidia.com>
Signed-off-by: Jonas Yang <joyang@nvidia.com>
Co-authored-by: ruit <ruit@nvidia.com>
Co-authored-by: Jonas Yang <joyang@nvidia.com>
1 parent e3cfb11 commit 32f5bef

File tree

12 files changed: +1130 additions, -276 deletions

docs/guides/sft.md

Lines changed: 47 additions & 0 deletions
@@ -161,3 +161,50 @@ As long as your custom dataset has the `formatted_ds` and `task_spec` attributes
 ## Evaluate the Trained Model
 
 Upon completion of the training process, you can refer to our [evaluation guide](eval.md) to assess model capabilities.
+
+## LoRA Configuration
+
+NeMo RL supports LoRA (Low-Rank Adaptation) for parameter-efficient fine-tuning. LoRA reduces trainable parameters by using low-rank matrices for weight updates while keeping the base model frozen.
+
+Notes:
+- LoRA is supported with the DTensor v2 and Megatron backends. DTensor v1 does not support LoRA (ensure `policy.dtensor_cfg._v2=true` when using DTensor).
+- Triton kernels are only used in the DTensor v2 path. For TP > 1, Automodel currently does not support Triton kernels (see the note below).
+
+### Configuration Parameters
+
+The LoRA configuration is specified under the `policy.dtensor_cfg.lora_cfg` section:
+
+```yaml
+policy:
+  dtensor_cfg:
+    lora_cfg:
+      enabled: False            # Set to True to enable LoRA fine-tuning
+      target_modules: []        # List of module names to apply LoRA to
+      exclude_modules: []       # List of module names to exclude from LoRA
+      match_all_linear: true    # Apply LoRA to all linear layers
+      dim: 8                    # LoRA rank (r): controls adaptation capacity
+      alpha: 32                 # LoRA scaling factor (effective lr = alpha/dim)
+      dropout: 0.0              # Dropout probability for LoRA layers
+      dropout_position: "post"  # Dropout position: "pre" or "post"
+      lora_A_init: "xavier"     # Initialization method: "xavier" or "uniform"
+      use_triton: true          # Use Triton-optimized kernels (DTensor v2 path)
+```
+
+### Parameter Details
+- **`enabled`** (bool): Whether to enable LoRA training
+- **`target_modules`** (list): Specific module names to apply LoRA to. An empty list with `match_all_linear=true` applies LoRA to all linear layers
+- **`exclude_modules`** (list): Module names to exclude from LoRA
+- **`match_all_linear`** (bool): When `true`, applies LoRA to all linear layers (overrides `target_modules`)
+- **`dim`** (int): LoRA rank (r). Lower values = fewer parameters but less capacity. Typical: 4, 8, 16, 32, 64
+- **`alpha`** (int): LoRA scaling factor. Effective learning rate multiplier = `alpha/dim`. Typical: 16, 32, 64
+- **`dropout`** (float): Dropout probability for regularization
+- **`dropout_position`** (str): Apply dropout before ("pre") or after ("post") LoRA
+- **`lora_A_init`** (str): Initialization method for the LoRA A matrix
+- **`use_triton`** (bool): Use Triton-optimized kernels for better performance. Used for DTensor v2 only. **Note**: [Automodel does not support Triton for TP > 1](https://github.com/NVIDIA-NeMo/Automodel/blob/b2db55eee98dfe81a8bfe5e23ac4e57afd8ab261/nemo_automodel/recipes/llm/train_ft.py#L199). Set to `false` when `tensor_parallel_size > 1` to avoid compatibility issues
+
+### Example Usage
+
+```bash
+uv run examples/run_sft.py policy.dtensor_cfg.lora_cfg.enabled=true
+```
+
+For more details on LoRA, see [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685).
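To make the `dim`/`alpha` trade-off in the guide above concrete, here is a minimal generic PyTorch sketch of the LoRA update (an illustration of the idea only, not NeMo RL's or Automodel's implementation): the base weight stays frozen, a rank-`dim` update `B @ A` is learned, and the update is scaled by `alpha/dim`, so the defaults `dim=8`, `alpha=32` give a scaling factor of 4.

```python
import torch
import torch.nn as nn

# Generic LoRA sketch (illustration only): frozen base weight + scaled low-rank update.
d_in, d_out, dim, alpha = 4096, 4096, 8, 32

base = nn.Linear(d_in, d_out, bias=False)
base.weight.requires_grad_(False)           # base model stays frozen

lora_A = nn.Linear(d_in, dim, bias=False)   # A: d_in -> dim
lora_B = nn.Linear(dim, d_out, bias=False)  # B: dim -> d_out
nn.init.xavier_normal_(lora_A.weight)       # "xavier" init for A
nn.init.zeros_(lora_B.weight)               # B starts at zero, so the adapter is a no-op at step 0

def lora_forward(x: torch.Tensor) -> torch.Tensor:
    scaling = alpha / dim                   # 32 / 8 = 4.0 effective scaling factor
    return base(x) + scaling * lora_B(lora_A(x))

full_params = d_in * d_out                  # 16,777,216 weights if fully fine-tuned
lora_params = dim * (d_in + d_out)          # 65,536 trainable LoRA weights (~0.4%)
print(lora_forward(torch.randn(2, d_in)).shape, full_params, lora_params)
```

Raising `dim` grows adapter capacity and trainable-parameter count linearly, while `alpha` only rescales the learned update.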
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+defaults: ../../sft.yaml
+sft:
+  max_num_steps: 350
+  val_period: 20
+  val_global_batch_size: 128
+  val_micro_batch_size: 2
+checkpointing:
+  checkpoint_dir: results/sft-tmblog-llama3.1-8b
+  save_period: 20
+policy:
+  model_name: meta-llama/Llama-3.1-8B
+  tokenizer:
+    name: meta-llama/Llama-3.1-8B-Instruct
+    chat_template: default
+  dtensor_cfg:
+    lora_cfg:
+      enabled: true
+      dim: 128
+      alpha: 128
+  train_global_batch_size: 128
+  max_total_sequence_length: 4096
+  make_sequence_length_divisible_by: 2
+  optimizer:
+    kwargs:
+      lr: 2.0e-05
+      weight_decay: 0.01
+      eps: 1.0e-08
+data:
+  dataset_name: tulu3
+  add_generation_prompt: true
+  seed: 42
+logger:
+  log_dir: logs/sft-tmblog-llama3.1-8b
+  tensorboard_enabled: false
+  wandb:
+    project: nemo-rl
+    name: sft-tmblog-llama3.1-8b
+  tensorboard:
+    log_dir: tb_logs-sft-dev-tulu3
+cluster:
+  gpus_per_node: 8

examples/configs/sft.yaml

Lines changed: 14 additions & 0 deletions
@@ -36,6 +36,7 @@ policy:
   offload_optimizer_for_logprob: false
 
   dtensor_cfg:
+    _v2: true
     enabled: true
     env_vars: {}
     cpu_offload: False
@@ -44,6 +45,19 @@ policy:
     tensor_parallel_size: 1
     context_parallel_size: 1
     custom_parallel_plan: null
+
+    # LoRA (Low-Rank Adaptation) Configuration
+    lora_cfg:
+      enabled: False # Set to True to enable LoRA fine-tuning
+      target_modules: [] # List of module names to apply LoRA (empty list with match_all_linear=true applies to all linear layers)
+      exclude_modules: [] # List of module names to exclude from LoRA
+      match_all_linear: true # If True, applies LoRA to all linear layers (overrides target_modules)
+      dim: 8 # LoRA rank (r): lower rank = fewer parameters but less capacity. Typical values: 4, 8, 16, 32, 64
+      alpha: 32 # LoRA scaling factor: effective learning rate multiplier = alpha/dim. Typical values: 16, 32, 64
+      dropout: 0.0 # Dropout probability applied to LoRA layers (0.0 = no dropout)
+      dropout_position: "post" # Where to apply dropout: "pre" (before LoRA) or "post" (after LoRA)
+      lora_A_init: "xavier" # Initialization method for LoRA A matrix: "xavier" or "uniform"
+      use_triton: true # Use Triton-optimized kernels for LoRA (faster but requires flash-attn). Disable when tensor_parallel_size > 1
 
   dynamic_batching:
     enabled: false

nemo_rl/models/policy/__init__.py

Lines changed: 18 additions & 0 deletions
@@ -17,6 +17,23 @@
 from nemo_rl.models.generation.interfaces import GenerationConfig
 
 
+class LoRAConfigDisabled(TypedDict):
+    enabled: Literal[False]
+
+
+class LoRAConfig(TypedDict):
+    enabled: Literal[True]
+    target_modules: list[str]
+    exclude_modules: list[str]
+    match_all_linear: NotRequired[bool]
+    dim: int
+    alpha: int
+    dropout: float
+    dropout_position: Literal["pre", "post"]
+    lora_A_init: str
+    use_triton: NotRequired[bool]
+
+
 class DTensorConfigDisabled(TypedDict):
     enabled: Literal[False]
 
@@ -32,6 +49,7 @@ class DTensorConfig(TypedDict):
     context_parallel_size: int
     custom_parallel_plan: str | None
     clear_cache_every_n_steps: NotRequired[int | None]
+    lora_cfg: NotRequired[LoRAConfig | LoRAConfigDisabled]
 
 
 class SequencePackingConfigDisabled(TypedDict):
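For reference, a dict literal satisfying the new `LoRAConfig` TypedDict could look like the following sketch (the values simply mirror the defaults added to `examples/configs/sft.yaml`):

```python
from nemo_rl.models.policy import LoRAConfig

lora_cfg: LoRAConfig = {
    "enabled": True,
    "target_modules": [],        # empty list + match_all_linear=True -> all linear layers
    "exclude_modules": [],
    "match_all_linear": True,
    "dim": 8,
    "alpha": 32,
    "dropout": 0.0,
    "dropout_position": "post",
    "lora_A_init": "xavier",
    "use_triton": True,          # set False when tensor_parallel_size > 1
}
```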

nemo_rl/models/policy/lm_policy.py

Lines changed: 4 additions & 0 deletions
@@ -112,6 +112,10 @@ def __init__(
         if use_v2:
             worker_builder_cls = "nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2"
         else:
+            assert (
+                config["dtensor_cfg"].get("lora_cfg", {}).get("enabled", False)
+                is False
+            ), "LoRA is not supported for DTensorPolicyWorker V1"
             worker_builder_cls = "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker"
 
         tp_size = config["dtensor_cfg"]["tensor_parallel_size"]

nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py

Lines changed: 40 additions & 0 deletions
@@ -14,18 +14,24 @@
 
 import gc
 import itertools
+import math
 import os
 import warnings
 from collections import defaultdict
 from contextlib import AbstractContextManager, contextmanager, nullcontext
 from typing import Any, Generator, Optional, cast
 
+import nemo_automodel.components._peft.lora as _lora_mod
 import ray
 import torch
 from accelerate import init_empty_weights
 from nemo_automodel import (
     NeMoAutoModelForSequenceClassification,
 )
+from nemo_automodel.components._peft.lora import (
+    PeftConfig,
+    apply_lora_to_linear_modules,
+)
 from nemo_automodel.components.distributed.cp_utils import (
     create_context_parallel_ctx,
     get_train_context,
@@ -93,6 +99,15 @@
 from nemo_rl.utils.packed_tensor import packed_broadcast_producer
 
 
+# TODO: @ruit remove this once we bump Automodel to 2d20e33a19d5e53a271b1403b507475e68ad14dc (https://github.com/NVIDIA-NeMo/RL/issues/1586)
+def _patched_init_lora_weights(self, init_method: str):
+    if init_method == "xavier":
+        nn.init.xavier_normal_(self.lora_A.weight.data)
+    else:
+        nn.init.kaiming_uniform_(self.lora_A.weight.data, a=math.sqrt(5))
+    self.lora_B.weight.data.zero_()
+
+
 @ray.remote(
     runtime_env=get_runtime_env_for_policy_worker("dtensor_policy_worker_v2")
 ) # pragma: no cover
@@ -222,6 +237,23 @@ def __init__(
 
         full_state_dict = None
         model_state_dict_keys = None
+
+        # lora config
+        lora_cfg = self.cfg["dtensor_cfg"].get("lora_cfg", None)
+        self.peft_config = None
+        self.lora_enabled = lora_cfg is not None and lora_cfg["enabled"]
+        # patch the init_lora_weights method to use the xavier initialization
+        _lora_mod.LinearLoRA.init_lora_weights = _patched_init_lora_weights
+        if self.lora_enabled:
+            if self.cfg["dtensor_cfg"]["tensor_parallel_size"] > 1:
+                assert not lora_cfg["use_triton"], (
+                    "Triton is not supported when tensor_parallel_size > 1"
+                )
+            # Always use float32 since FSDP requires all parameters to be in the same dtype.
+            # autocast should cast the weights to the correct dtype during the forward pass.
+            cfg_dict_with_dtype = {**lora_cfg, "lora_dtype": "torch.float32"}
+            self.peft_config = PeftConfig.from_dict(cfg_dict_with_dtype)
+
         if self.rank == 0:
             print(f"[Rank {self.rank}] Loading model {model_name} on CPU...")
             model = model_class.from_pretrained(
@@ -233,6 +265,9 @@
                 torch_dtype=str(model_config.torch_dtype),
             )
 
+            if self.lora_enabled:
+                apply_lora_to_linear_modules(model, self.peft_config)
+
             full_state_dict = model.state_dict()
             # Store the original model state dict keys before any parallelization
             model_state_dict_keys = list(full_state_dict.keys())
@@ -255,6 +290,8 @@ def __init__(
                 trust_remote_code=True,
                 torch_dtype=str(model_config.torch_dtype),
             )
+            if self.lora_enabled:
+                apply_lora_to_linear_modules(self.model, self.peft_config)
 
         if self.model.config.pad_token_id is None:
             self.model.config.pad_token_id = tokenizer.pad_token_id
@@ -1857,6 +1894,9 @@ def save_checkpoint(
                 "peft_config",
             }
         }
+        if self.lora_enabled:
+            checkpoint_kwargs["is_peft"] = True
+            checkpoint_kwargs["peft_config"] = self.peft_config
 
         save_checkpoint(
             model=self.model,
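Putting the worker's LoRA pieces together, the flow is roughly: take the `lora_cfg` dict, force `lora_dtype` to float32 so FSDP sees a single parameter dtype, build a `PeftConfig`, and wrap the model's linear layers before parallelization. A condensed, standalone sketch mirroring the code above (the model name is only an example; this is not the worker itself):

```python
import torch
from nemo_automodel.components._peft.lora import PeftConfig, apply_lora_to_linear_modules
from transformers import AutoModelForCausalLM

lora_cfg = {
    "enabled": True,
    "target_modules": [],
    "exclude_modules": [],
    "match_all_linear": True,
    "dim": 8,
    "alpha": 32,
    "dropout": 0.0,
    "dropout_position": "post",
    "lora_A_init": "xavier",
    "use_triton": False,  # keep False when tensor_parallel_size > 1
}

# Keep adapter weights in float32, as the worker does, so FSDP sees one dtype;
# autocast handles the cast during the forward pass.
peft_config = PeftConfig.from_dict({**lora_cfg, "lora_dtype": "torch.float32"})

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype=torch.bfloat16)
apply_lora_to_linear_modules(model, peft_config)  # wrap matching linear layers with LoRA adapters
```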
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+#!/bin/bash
+
+# clean up checkpoint directory on exit
+trap "rm -rf /tmp/lora_sft_checkpoints" EXIT
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
+PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
+# Mark the current repo as safe, since wandb fetches metadata about the repo
+git config --global --add safe.directory $PROJECT_ROOT
+
+set -eou pipefail
+
+EXP_NAME=$(basename $0 .sh)
+EXP_DIR=$SCRIPT_DIR/$EXP_NAME
+LOG_DIR=$EXP_DIR/logs
+JSON_METRICS=$EXP_DIR/metrics.json
+RUN_LOG=$EXP_DIR/run.log
+export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}
+
+rm -rf $EXP_DIR $LOG_DIR
+mkdir -p $EXP_DIR $LOG_DIR
+
+cd $PROJECT_ROOT
+uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
+    $PROJECT_ROOT/examples/run_sft.py \
+    policy.model_name=Qwen/Qwen3-0.6B \
+    cluster.gpus_per_node=2 \
+    sft.max_num_steps=3 \
+    sft.val_batches=1 \
+    sft.val_period=3 \
+    policy.dtensor_cfg.lora_cfg.enabled=true \
+    logger.tensorboard_enabled=true \
+    logger.log_dir=$LOG_DIR \
+    logger.wandb_enabled=false \
+    logger.monitor_gpus=true \
+    checkpointing.enabled=true \
+    checkpointing.save_period=3 \
+    checkpointing.checkpoint_dir=/tmp/lora_sft_checkpoints \
+    "$@" \
+    2>&1 | tee $RUN_LOG
+
+uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
+
+uv run tests/check_metrics.py $JSON_METRICS \
+    'data["train/loss"]["3"] < 5.9'
+
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+#!/bin/bash
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
+source $SCRIPT_DIR/common.env
+
+# ===== BEGIN CONFIG =====
+NUM_NODES=1
+STEPS_PER_RUN=50
+MAX_STEPS=50
+NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))  # Round up
+NUM_MINUTES=30
+# ===== END CONFIG =====
+
+exit_if_max_steps_reached
+
+# Run the experiment
+cd $PROJECT_ROOT
+uv run examples/run_sft.py \
+    --config $CONFIG_PATH \
+    sft.max_num_steps=$MAX_STEPS \
+    logger.log_dir=$LOG_DIR \
+    logger.wandb_enabled=True \
+    logger.wandb.project=ruit_personal_debug \
+    logger.wandb.name=$EXP_NAME \
+    logger.monitor_gpus=True \
+    logger.tensorboard_enabled=True \
+    checkpointing.enabled=True \
+    checkpointing.checkpoint_dir=$CKPT_DIR \
+    $@ \
+    2>&1 | tee $RUN_LOG
+
+# Convert tensorboard logs to json
+uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
+
+# TODO: memory check will fail due to OOM tracked here https://github.com/NVIDIA-NeMo/RL/issues/263
+
+# Only run metrics if the target step is reached
+if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
+    uv run tests/check_metrics.py $JSON_METRICS \
+        'data["train/loss"]["1"] < 1.0' \
+        'data["train/loss"]["50"] < 0.8' \
+        'max(data["ray/node.0.gpu.0.mem_gb"]) < 50' \
+        'mean(data["timing/train/total_step_time"], 2) < 10'
+fi

tests/test_suites/nightly.txt

Lines changed: 3 additions & 0 deletions
@@ -66,6 +66,9 @@ tests/test_suites/llm/sft-llama3.2-1b-1n8g-fsdp2tp1.v3.sh
 tests/test_suites/llm/sft-llama3.1-8b-1n8g-fsdp2tp2.sh
 # dynamic batching
 tests/test_suites/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-dynamicbatch.sh
+# lora
+# Tulu3 dataset is not supported yet. Re-enable this test once PR https://github.com/NVIDIA-NeMo/RL/pull/1506 is merged.
+# tests/test_suites/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-lora.sh
 
 # Functional 32b test
 tests/test_suites/llm/sft-qwen2.5-32b-4n8g-fsdp2tp8sp-actckpt.v3.sh
