
Commit 76874a8

feat: Integrate vlm changes between DTensorPolicyWorker V1 and V2. (#982)
Signed-off-by: Felipe Vieira Frujeri <ffrujeri@nvidia.com>
1 parent f17f331 commit 76874a8

File tree

8 files changed, +293 -94 lines changed

Submodule Automodel updated 232 files

examples/configs/recipes/vlm/vlm_grpo-qwen2.5-vl-3b-instruct-clevr-1n2g-dtensor2tp1.v1.yaml

Lines changed: 1 addition & 0 deletions

@@ -44,6 +44,7 @@ policy:
   precision: "bfloat16"
 
   dtensor_cfg:
+    _v2: true
     enabled: true
     cpu_offload: False
     sequence_parallel: false

examples/configs/vlm_grpo_3B.yaml

Lines changed: 1 addition & 0 deletions

@@ -45,6 +45,7 @@ policy:
   precision: "bfloat16"
 
   dtensor_cfg:
+    _v2: true
     enabled: true
     cpu_offload: False
     sequence_parallel: false
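
Both GRPO VLM recipes gain the same one-line change: a `_v2: true` flag under `policy.dtensor_cfg`. Judging from the rest of this commit, the flag opts the recipe into the V2 DTensor policy worker; the actual selection logic is not part of this diff, so the sketch below is hypothetical, and the v1 module path in it is an assumption.

# Hypothetical dispatch sketch, NOT code from this commit: the v1 module path
# and the exact config access are assumptions used only for illustration.
def pick_dtensor_worker_cls(policy_cfg: dict):
    dtensor_cfg = policy_cfg.get("dtensor_cfg", {})
    if dtensor_cfg.get("enabled") and dtensor_cfg.get("_v2"):
        # Module shown in this commit.
        from nemo_rl.models.policy.dtensor_policy_worker_v2 import DTensorPolicyWorkerV2
        return DTensorPolicyWorkerV2
    # Assumed v1 location; verify against the repo before relying on it.
    from nemo_rl.models.policy.dtensor_policy_worker import DTensorPolicyWorker
    return DTensorPolicyWorker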

nemo_rl/models/policy/dtensor_policy_worker_v2.py

Lines changed: 60 additions & 12 deletions

@@ -23,10 +23,11 @@
 import torch
 from accelerate import init_empty_weights
 from nemo_automodel import (
-    NeMoAutoModelForCausalLM,
     NeMoAutoModelForSequenceClassification,
 )
-from nemo_automodel.components._transformers.utils import sliding_window_overwrite
+from nemo_automodel.components._transformers.utils import (
+    sliding_window_overwrite,
+)
 from nemo_automodel.components.distributed.cp_utils import (
     create_context_parallel_ctx,
     get_train_context,
@@ -56,6 +57,7 @@
 from torch.distributed.tensor import DTensor, Shard
 from transformers import (
     AutoConfig,
+    AutoProcessor,
     AutoTokenizer,
 )
 from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
@@ -79,6 +81,7 @@
     get_handle_from_tensor,
     get_runtime_env_for_policy_worker,
     import_class_from_path,
+    resolve_model_class,
 )
 from nemo_rl.utils.native_checkpoint import (
     load_checkpoint,
@@ -105,12 +108,19 @@ def __init__(
         self,
         config: PolicyConfig,
         tokenizer: AutoTokenizer,
+        processor: Optional[AutoProcessor] = None,
         weights_path: Optional[str] = None,
         optimizer_path: Optional[str] = None,
         init_optimizer: bool = True,
         init_reference_model: bool = True,
         **kwargs: Any,
     ):
+        self.tokenizer = tokenizer
+        self.processor = processor
+        self.is_vlm = processor is not None
+
+        print(f"Initializing DTensorPolicyWorkerV2 with is_vlm={self.is_vlm}")
+
         self.is_generation_colocated = None
         if "generation" in config and config["generation"] is not None:
             self.is_generation_colocated = config["generation"]["colocated"]["enabled"]
@@ -146,6 +156,9 @@ def __init__(
             print(f"[Rank {self.rank}] Loading model {model_name} on CPU...")
         self.enable_seq_packing = self.cfg["sequence_packing"]["enabled"]
         if self.enable_seq_packing:
+            assert not self.is_vlm, (
+                "Sequence packing is not supported for VLM models. Please set policy.sequence_packing.enabled = False to train VLM models."
+            )
             print(
                 f"[Rank {self.rank}] Sequence packing is enabled for model {model_name}"
             )
@@ -195,7 +208,8 @@ def __init__(
            else:
                raise ValueError(f"Unknown reward model type: {rm_type}")
        else:
-           model_class = NeMoAutoModelForCausalLM
+           # DO NOT assume AutoModelForCausalLM, multimodal models can inherit from AutoModelForImageTextToText, AutoModelForTextToWaveform, etc.
+           model_class = resolve_model_class(model_config.model_type)

        full_state_dict = None
        if self.rank == 0:
@@ -205,6 +219,7 @@ def __init__(
                device_map="cpu",  # load weights onto CPU initially
                trust_remote_code=True,
                config=model_config,
+               torch_dtype=str(model_config.torch_dtype),
            )

            full_state_dict = model.state_dict()
@@ -224,19 +239,12 @@ def __init__(
            if self.enable_seq_packing
            else None,
            trust_remote_code=True,
+           torch_dtype=str(model_config.torch_dtype),
        )

        if self.model.config.pad_token_id is None:
            self.model.config.pad_token_id = tokenizer.pad_token_id

-       # caching since this property is not always preserved after FSDP
-       self.tokenizer = tokenizer
-
-       # ------------------------------------------------
-       # 3) Move to GPU + Composable FSDP
-       # (Initialize device mesh, shard submodules, then shard entire model)
-       # ------------------------------------------------
-
        tp_size = self.cfg["dtensor_cfg"]["tensor_parallel_size"]
        cp_size = self.cfg["dtensor_cfg"]["context_parallel_size"]
        if cp_size > 1 and self.enable_seq_packing:
@@ -266,6 +274,10 @@ def __init__(
                "See https://github.com/NVIDIA-NeMo/RL/issues/659 for more details."
            )

+           assert not self.is_vlm, (
+               "Context parallel is yet not supported for VLM models. Please set cp_size = 1 to train VLM models."
+           )
+
        # For FSDP2 compatibility, we need to support HSDP structure
        # For now, we use dp_replicate_size = 1 (no hybrid sharding)
        dp_replicate_size = 1
@@ -299,6 +311,10 @@ def __init__(
        self.cp_size = cp_size
        self.device_mesh = device_mesh

+       # ------------------------------------------------
+       # 3) Move to GPU + Composable FSDP
+       # (Initialize device mesh, shard submodules, then shard entire model)
+       # ------------------------------------------------
        self.model = fsdp2_strategy_parallelize(
            self.model,
            device_mesh=self.device_mesh,
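
`DTensorPolicyWorkerV2.__init__` now accepts an optional `AutoProcessor` and flips `self.is_vlm` on when one is provided. Below is a minimal, illustrative sketch of preparing those inputs; the checkpoint name is only an example, and the worker construction is left as a comment because a full `PolicyConfig` is out of scope here.

# Illustrative only: preparing the new `processor` argument for the V2 worker.
# The checkpoint name is an example; `policy_config` stands in for a real
# PolicyConfig dict from one of the recipes above.
from transformers import AutoProcessor, AutoTokenizer

model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
try:
    processor = AutoProcessor.from_pretrained(model_name)
except Exception:
    processor = None  # text-only checkpoints may not ship a processor

# worker = DTensorPolicyWorkerV2(
#     config=policy_config, tokenizer=tokenizer, processor=processor
# )
# Inside __init__, self.is_vlm = processor is not None selects the VLM code path.
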
@@ -597,8 +613,18 @@ def train(
                    ).repeat(batch_size, 1)
                    flash_attn_kwargs = {}

+               # add vlm kwargs to model call
+               vlm_kwargs = mb.get_multimodal_dict(
+                   as_tensors=True, device=input_ids.device
+               )
+               if len(vlm_kwargs) > 0:
+                   position_ids = None
+
                context_parallel_ctx = None
                if self.cp_size > 1:
+                   assert len(vlm_kwargs) == 0, (
+                       f"multimodal kwargs={vlm_kwargs} are not supported for context parallel"
+                   )
                    seq_index = torch.arange(
                        seq_len, device=input_ids.device
                    ).repeat(1, 1)
@@ -624,6 +650,7 @@ def train(
                    position_ids=position_ids,
                    use_cache=False,
                    flash_attn_kwargs=flash_attn_kwargs,
+                   **vlm_kwargs,
                )

                if self._is_reward_model:
@@ -632,6 +659,9 @@ def train(
                    # is not supported for reward models.
                    assert not flash_attn_kwargs
                    del model_args["flash_attn_kwargs"]
+               # remove flash_attn_kwargs if there are multimodal kwargs
+               if len(vlm_kwargs) > 0:
+                   del model_args["flash_attn_kwargs"]

                outputs = self.model(**model_args)

@@ -859,9 +889,15 @@ def get_logprobs(
                step += 1
                input_ids = lp_batch.get("input_ids").cuda()
                input_lengths = lp_batch.get("input_lengths")
+               vlm_kwargs = lp_batch.get_multimodal_dict(
+                   as_tensors=True, device=input_ids.device
+               )

                batch_size, seq_len = input_ids.shape
                if self.enable_seq_packing:
+                   assert len(vlm_kwargs) == 0, (
+                       "multimodal kwargs are not supported for sequence packing"
+                   )
                    input_ids, position_ids, _ = pack_sequences(
                        input_ids=input_ids,
                        input_lengths=input_lengths,
@@ -901,8 +937,15 @@ def get_logprobs(
                    (batch_size, seq_len), dtype=torch.long, device=input_ids.device
                )

+               # if there are multimodal kwargs, we don't need to add position_ids (computed internally)
+               if len(vlm_kwargs) > 0:
+                   position_ids = None
+
                context_parallel_ctx = None
                if self.cp_size > 1:
+                   assert len(vlm_kwargs) == 0, (
+                       "multimodal kwargs are not supported for context parallel"
+                   )
                    seq_index = torch.arange(seq_len, device=input_ids.device).repeat(
                        1, 1
                    )
@@ -918,13 +961,18 @@ def get_logprobs(

                with get_train_context(False, False, context_parallel_ctx)():
                    with torch.autocast(device_type="cuda", dtype=self.dtype):
-                       outputs = self.model(
+                       model_args = dict(
                            input_ids=input_ids,
                            attention_mask=attention_mask_input_all_ones,
                            position_ids=position_ids,
                            use_cache=False,
                            flash_attn_kwargs=flash_attn_kwargs,
+                           **vlm_kwargs,
                        )
+                       if len(vlm_kwargs) > 0:
+                           del model_args["flash_attn_kwargs"]
+
+                       outputs = self.model(**model_args)

                logits = outputs.logits
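
In both `train()` and `get_logprobs()` the pattern is the same: pull multimodal tensors from the batch with `get_multimodal_dict(as_tensors=True, device=...)`, splat them into the model call, let the model derive position ids when they are present, and drop `flash_attn_kwargs`, which are only built for the text-only path. The helper below is a condensed sketch of that pattern, not code from this commit.

# Condensed sketch (hypothetical helper, not part of the commit) of how the
# worker assembles model arguments for text-only vs. multimodal batches.
from typing import Any, Optional

import torch


def build_model_args(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    position_ids: Optional[torch.Tensor],
    flash_attn_kwargs: dict[str, Any],
    vlm_kwargs: dict[str, Any],
) -> dict[str, Any]:
    if len(vlm_kwargs) > 0:
        # VLMs compute position/rotary information internally from the
        # multimodal inputs, so explicit position_ids are not passed.
        position_ids = None
    model_args = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        use_cache=False,
        flash_attn_kwargs=flash_attn_kwargs,
        **vlm_kwargs,
    )
    if len(vlm_kwargs) > 0:
        # flash-attn varlen kwargs are only built for the text-only path.
        del model_args["flash_attn_kwargs"]
    return model_args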

nemo_rl/models/policy/utils.py

Lines changed: 49 additions & 19 deletions

@@ -14,37 +14,67 @@

 import importlib
 import os
-from collections import defaultdict
-from typing import Any
+from typing import Any, Dict

 import torch
-from torch import nn
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
     AutoModelForImageTextToText,
     AutoModelForTextToWaveform,
 )

+# Try to import nemo_automodel classes, fallback to None if not available
+try:
+    from nemo_automodel.components._transformers.auto_model import (
+        NeMoAutoModelForCausalLM,
+        NeMoAutoModelForImageTextToText,
+        NeMoAutoModelForTextToWaveform,
+    )
+
+    NEMO_AUTOMODEL_AVAILABLE = True
+except ImportError:
+    # nemo_automodel is not installed, classes will be None
+    NeMoAutoModelForCausalLM = None  # type: ignore
+    NeMoAutoModelForImageTextToText = None  # type: ignore
+    NeMoAutoModelForTextToWaveform = None  # type: ignore
+    NEMO_AUTOMODEL_AVAILABLE = False
+
 from nemo_rl.distributed.worker_group_utils import get_nsight_config_if_pattern_matches

 # an automodel factory for loading the huggingface models from correct class
-AUTOMODEL_FACTORY = defaultdict(lambda: AutoModelForCausalLM)
-AUTOMODEL_FACTORY["qwen2_5_vl"] = AutoModelForImageTextToText
-AUTOMODEL_FACTORY["qwen2_vl"] = AutoModelForImageTextToText
-AUTOMODEL_FACTORY["qwen2_5_omni"] = AutoModelForTextToWaveform
-AUTOMODEL_FACTORY["llava"] = AutoModelForImageTextToText
-AUTOMODEL_FACTORY["internvl"] = AutoModelForImageTextToText
-AUTOMODEL_FACTORY["gemma3"] = AutoModelForImageTextToText
-AUTOMODEL_FACTORY["smolvlm"] = AutoModelForImageTextToText
-AUTOMODEL_FACTORY["mistral3"] = AutoModelForImageTextToText
-AUTOMODEL_FACTORY["llama4"] = AutoModelForImageTextToText
-
-
-def resolve_model_class(model_name: str) -> nn.Module:
-    if model_name.lower() in AUTOMODEL_FACTORY.keys():
-        return AUTOMODEL_FACTORY[model_name.lower()]
-    return AutoModelForCausalLM
+
+AUTOMODEL_FACTORY: Dict[str, Any] = {
+    "qwen2_5_vl": AutoModelForImageTextToText,
+    "qwen2_vl": AutoModelForImageTextToText,
+    "qwen2_5_omni": AutoModelForTextToWaveform,
+    "llava": AutoModelForImageTextToText,
+    "internvl": AutoModelForImageTextToText,
+    "gemma3": AutoModelForImageTextToText,
+    "smolvlm": AutoModelForImageTextToText,
+    "mistral3": AutoModelForImageTextToText,
+    "llama4": AutoModelForImageTextToText,
+}
+
+if NEMO_AUTOMODEL_AVAILABLE:
+    AUTOMODEL_FACTORY = {
+        "qwen2_5_vl": NeMoAutoModelForImageTextToText,
+        "qwen2_vl": NeMoAutoModelForImageTextToText,
+        "qwen2_5_omni": NeMoAutoModelForTextToWaveform,
+        "llava": NeMoAutoModelForImageTextToText,
+        "internvl": NeMoAutoModelForImageTextToText,
+        "gemma3": NeMoAutoModelForImageTextToText,
+        "smolvlm": NeMoAutoModelForImageTextToText,
+        "mistral3": NeMoAutoModelForImageTextToText,
+        "llama4": NeMoAutoModelForImageTextToText,
+    }
+
+
+def resolve_model_class(model_name: str) -> Any:
+    """Resolve the appropriate model class for a given model name."""
+    if NEMO_AUTOMODEL_AVAILABLE:
+        return AUTOMODEL_FACTORY.get(model_name.lower(), NeMoAutoModelForCausalLM)
+    return AUTOMODEL_FACTORY.get(model_name.lower(), AutoModelForCausalLM)


 def is_vllm_v1_engine_enabled() -> bool:
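
`resolve_model_class` now returns the NeMo Automodel wrappers when `nemo_automodel` is importable and the plain `transformers` auto-classes otherwise, keyed by `config.model_type` with a CausalLM fallback. A minimal usage sketch, assuming `transformers` is installed and the checkpoint name (used purely for illustration) is reachable:

# Minimal usage sketch of the factory; the checkpoint name is illustrative.
from transformers import AutoConfig

from nemo_rl.models.policy.utils import resolve_model_class

model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

# "qwen2_5_vl" maps to an ImageTextToText class; unknown types fall back to CausalLM.
model_cls = resolve_model_class(config.model_type)
model = model_cls.from_pretrained(model_name, config=config, trust_remote_code=True)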

pyproject.toml

Lines changed: 11 additions & 3 deletions

@@ -144,13 +144,16 @@ megatron-bridge = { workspace = true }
 nemo_run = { git = "https://github.com/NVIDIA-NeMo/Run", rev = "414f0077c648fde2c71bb1186e97ccbf96d6844c" }
 # torch/torchvision/triton all come from the torch index in order to pick up aarch64 wheels
 torch = [
-  { index = "pytorch-cu128" },
+  { index = "pytorch-cu128", marker = "sys_platform != 'darwin'" },
+  { index = "pypi", marker = "sys_platform == 'darwin'" },
 ]
 torchvision = [
-  { index = "pytorch-cu128" },
+  { index = "pytorch-cu128", marker = "sys_platform != 'darwin'" },
+  { index = "pypi", marker = "sys_platform == 'darwin'" },
 ]
 triton = [
-  { index = "pytorch-cu128" },
+  { index = "pytorch-cu128", marker = "sys_platform != 'darwin'" },
+  { index = "pypi", marker = "sys_platform == 'darwin'" },
 ]
 causal-conv1d = { git = "https://github.com/Dao-AILab/causal-conv1d", tag = "v1.5.0.post8" }
 mamba-ssm = { git = "https://github.com/state-spaces/mamba.git", rev = "2e16fc3062cdcd4ebef27a9aa4442676e1c7edf4" }
@@ -162,6 +165,11 @@ members = [
     "3rdparty/Megatron-Bridge-workspace",
 ]

+[[tool.uv.index]]
+name = "pypi"
+url = "https://pypi.org/simple"
+explicit = true
+
 [[tool.uv.index]]
 name = "pytorch-cu128"
 url = "https://download.pytorch.org/whl/cu128"

pyrefly.toml

Lines changed: 1 addition & 0 deletions

@@ -1,5 +1,6 @@
 python-version = "3.12.0"
 replace-imports-with-any = [
+    "nemo_automodel.*",
     "pynvml.*",
     "hydra._internal.*",
     "hydra.core.override_parser.*",
