
Commit 223bfa8

gshennvm and terrykong authored
feat: add nemotron5 sharding (#481)
Signed-off-by: Terry Kong <terryk@nvidia.com>
Co-authored-by: Terry Kong <terryk@nvidia.com>
1 parent 18b9e2c commit 223bfa8

12 files changed: +296 −5 lines

docs/testing.md

Lines changed: 9 additions & 3 deletions

@@ -4,9 +4,15 @@ This guide outlines how to test NeMo RL using unit and functional tests, detaili
 
 ## Unit Tests
 
-:::{important}
-Unit tests require 2 GPUs to test the full suite.
-:::
+> [!IMPORTANT]
+> Unit tests require 2 GPUs to test the full suite.
+
+> [!TIP]
+> Some unit tests require setting up test assets which you can download with
+> ```sh
+> uv run tests/unit/prepare_unit_test_assets.py
+> ```
+
 
 ```sh
 # Run the unit tests using local GPUs

nemo_rl/distributed/virtual_cluster.py

Lines changed: 3 additions & 0 deletions

@@ -45,6 +45,9 @@ class PY_EXECUTABLES:
     # Use NeMo-RL direct dependencies.
     BASE = "uv run --locked"
 
+    # Use NeMo-RL direct dependencies.
+    AUTOMODEL = "uv run --locked --extra automodel"
+
     # Use NeMo-RL direct dependencies and vllm.
     VLLM = "uv run --locked --extra vllm"
 
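Each PY_EXECUTABLES entry is a `uv run` prefix (presumably the one worker processes are launched under); the new AUTOMODEL entry selects the `automodel` extra, which, per the pyproject.toml change below, now pulls in mamba-ssm and causal-conv1d. A minimal sketch of how such a prefix composes into a command; the build_worker_command helper is hypothetical and not part of NeMo RL:

```python
# Illustrative sketch only: shows how a PY_EXECUTABLES-style prefix such as
# AUTOMODEL composes into a full command. build_worker_command is hypothetical
# and not part of NeMo RL.
import shlex

AUTOMODEL = "uv run --locked --extra automodel"  # mirrors the value added above


def build_worker_command(py_executable: str, script: str) -> list[str]:
    # Split the prefix into argv tokens and append the script to run under it.
    return shlex.split(py_executable) + [script]


print(build_worker_command(AUTOMODEL, "worker.py"))
# -> ['uv', 'run', '--locked', '--extra', 'automodel', 'worker.py']
```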

nemo_rl/models/dtensor/parallelize.py

Lines changed: 84 additions & 1 deletion

@@ -357,6 +357,76 @@ def get_hf_tp_plan(model: PreTrainedModel):
     return hf_tp_plan
 
 
+def _parallelize_nm5_h(
+    model,
+    dp_mesh: DeviceMesh,
+    tp_mesh: DeviceMesh,
+    param_dtype: torch.dtype,
+    sequence_parallel: bool = False,
+    activation_checkpointing: bool = False,
+    cpu_offload: bool = False,
+    custom_parallel_plan: Optional[Union[dict, str]] = None,
+) -> torch.distributed.fsdp.FSDPModule:
+    """Parallelize a NemotronHForCausalLM model across data and tensor parallel dimensions."""
+    assert not sequence_parallel, (
+        "Sequence parallelism is not supported for NemotronHForCausalLM"
+    )
+    assert custom_parallel_plan is None, (
+        "Custom parallel plan is not supported for NemotronHForCausalLM"
+    )
+
+    model_tp_plan: dict[str, ParallelStyle] = {
+        "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
+    }
+
+    mlp_tp_plan: dict[str, ParallelStyle] = {
+        "mixer.up_proj": ColwiseParallel(),
+        "mixer.down_proj": RowwiseParallel(),
+    }
+
+    layers: torch.nn.ModuleList = model.backbone.layers
+    parallelize_module(model, tp_mesh, model_tp_plan)
+
+    for layer in model.backbone.layers:
+        if layer.block_type == "mlp":
+            parallelize_module(layer, tp_mesh, mlp_tp_plan)
+
+    if activation_checkpointing:
+        for i in range(len(layers)):
+            if layers[i].block_type == "mlp":
+                layers[i] = checkpoint_wrapper(layers[i])
+
+            if layers[i].block_type == "mamba":
+                layers[i] = checkpoint_wrapper(layers[i])
+
+    mp_policy = MixedPrecisionPolicy(
+        param_dtype=param_dtype,
+        reduce_dtype=torch.float32,
+        output_dtype=torch.float32,
+    )
+
+    offload_policy = (
+        CPUOffloadPolicy(pin_memory=False)
+        if cpu_offload
+        else torch.distributed.fsdp.OffloadPolicy
+    )
+
+    for layer in layers:
+        fully_shard(
+            layer, mesh=dp_mesh, mp_policy=mp_policy, offload_policy=offload_policy
+        )
+
+    # do not reshard after forward for root model
+    # because its parameters will be used in backward immediately
+    return fully_shard(
+        model,
+        mesh=dp_mesh,
+        mp_policy=mp_policy,
+        offload_policy=offload_policy,
+        reshard_after_forward=False,
+    )
+
+
 def _parallelize_model(
     model: Union[
         Qwen2ForCausalLM,
@@ -394,7 +464,20 @@ def _parallelize_model(
         ValueError: If the model type is not supported for parallelization.
     """
     model_cls = type(model)
-    if model_cls == Gemma3ForConditionalGeneration:
+    if model_cls.__name__ == "NemotronHForCausalLM":
+        # need to do something special for nm5, since it's harder to shard the mamba layers
+        # nm5 is not importable, so we check the __name__ attribute
+        return _parallelize_nm5_h(
+            model,
+            dp_mesh,
+            tp_mesh,
+            param_dtype,
+            sequence_parallel,
+            activation_checkpointing,
+            cpu_offload,
+            custom_parallel_plan,
+        )
+    elif model_cls == Gemma3ForConditionalGeneration:
         layers: torch.nn.ModuleList = model.language_model.layers  # type: ignore
         num_attention_heads = model.config.text_config.num_attention_heads
         num_key_value_heads = model.config.text_config.num_key_value_heads
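In the plan above, only blocks with block_type == "mlp" get a tensor-parallel plan; the Mamba mixers are left unsplit and rely on FSDP sharding via fully_shard. The MLP split follows the usual column-then-row pattern: ColwiseParallel shards up_proj's output across TP ranks, and RowwiseParallel on down_proj consumes those shards and all-reduces the result. Below is a minimal standalone sketch of that pattern on a toy MLP, assuming an already-initialized distributed process group; it is not the NeMo RL code itself:

```python
# Standalone sketch of the Megatron-style column/row MLP split used above.
# TinyMLP is a stand-in module, not the Nemotron-H mixer; assumes
# torch.distributed is already initialized with `tp_size` CUDA ranks.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class TinyMLP(torch.nn.Module):
    def __init__(self, hidden: int = 256, inner: int = 1024):
        super().__init__()
        self.up_proj = torch.nn.Linear(hidden, inner, bias=False)
        self.down_proj = torch.nn.Linear(inner, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(torch.nn.functional.relu(self.up_proj(x)))


def shard_mlp(mlp: TinyMLP, tp_size: int) -> TinyMLP:
    tp_mesh = init_device_mesh("cuda", (tp_size,), mesh_dim_names=("tp",))
    # up_proj output is sharded column-wise; down_proj consumes the shards
    # row-wise and all-reduces its output back to a replicated tensor.
    plan = {"up_proj": ColwiseParallel(), "down_proj": RowwiseParallel()}
    return parallelize_module(mlp, tp_mesh, plan)
```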

nemo_rl/models/policy/dtensor_policy_worker.py

Lines changed: 1 addition & 0 deletions

@@ -260,6 +260,7 @@ def __init__(
         with init_empty_weights():
             self.model = model_class.from_config(
                 model_config,
+                trust_remote_code=True,
             )
 
         if self.model.config.pad_token_id is None:
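trust_remote_code=True is needed here because Nemotron-H's modeling code ships with the checkpoint on the Hub rather than inside transformers, so from_config must be allowed to execute it. A minimal sketch of the same meta-device construction outside the worker (model repo taken from the test fixture below; not the worker's actual code path):

```python
# Minimal sketch, assuming accelerate and transformers are installed and the
# Hub is reachable; mirrors the worker's pattern but is not NeMo RL code.
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained(
    "nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True
)
with init_empty_weights():
    # Parameters live on the meta device, so no real memory is allocated here;
    # actual weights are loaded and sharded later by the worker.
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
```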

pyproject.toml

Lines changed: 9 additions & 1 deletion

@@ -55,11 +55,17 @@ automodel = [
     # https://github.com/NVIDIA/TransformerEngine/blob/v2.3/transformer_engine/pytorch/attention/dot_product_attention/utils.py#L108
     # https://github.com/facebookresearch/xformers/blob/8354497deb2c04c67fbb2e2ad911e86530da0e90/xformers/ops/fmha/flash.py#L76
     "flash-attn==2.7.4.post1",
+    "mamba-ssm",
+    "causal-conv1d",
 ]
 vllm = [
     "vllm==0.10.0",
     # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved
     "flash-attn==2.7.4.post1",
+    # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved
+    "mamba-ssm",
+    # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved
+    "causal-conv1d",
 ]
 mcore = [
     # also need cudnn (https://developer.nvidia.com/cudnn-downloads?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=deb_network)
@@ -132,6 +138,8 @@ torchvision = [
 triton = [
     { index = "pytorch-cu128" },
 ]
+causal-conv1d = { git = "https://github.com/Dao-AILab/causal-conv1d", tag = "v1.5.0.post8" }
+mamba-ssm = { git = "https://github.com/state-spaces/mamba.git", rev = "2e16fc3062cdcd4ebef27a9aa4442676e1c7edf4" }
 
 [tool.uv.workspace]
 members = [
@@ -145,7 +153,7 @@ url = "https://download.pytorch.org/whl/cu128"
 explicit = true
 
 [tool.uv]
-no-build-isolation-package = ["transformer-engine-torch", "transformer-engine", "flash-attn"]
+no-build-isolation-package = ["transformer-engine-torch", "transformer-engine", "flash-attn", "mamba-ssm", "causal-conv1d"]
 # Always apply the build group since dependencies like TE/mcore/nemo-run require build dependencies
 # and this lets us assume they are implicitly installed with a simply `uv sync`. Ideally, we'd
 # avoid including these in the default dependency set, but for now it's required.
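Both mamba-ssm and causal-conv1d compile CUDA extensions that import torch at build time, which is presumably why they join flash-attn and TransformerEngine in no-build-isolation-package.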

tests/unit/L0_Unit_Tests_Generation.sh

Lines changed: 2 additions & 0 deletions

@@ -15,6 +15,8 @@
 #!/bin/bash
 set -xeuo pipefail # Exit immediately if a command exits with a non-zero status
 
+uv run tests/unit/prepare_unit_test_assets.py
+
 cd /opt/nemo-rl
 uv run --no-sync bash -x ./tests/run_unit.sh unit/models/generation/ --cov=nemo_rl --cov-report=term-missing --cov-report=json --hf-gated
 

tests/unit/L0_Unit_Tests_Other.sh

Lines changed: 2 additions & 0 deletions

@@ -15,6 +15,8 @@
 #!/bin/bash
 set -xeuo pipefail # Exit immediately if a command exits with a non-zero status
 
+uv run tests/unit/prepare_unit_test_assets.py
+
 cd /opt/nemo-rl
 uv run --no-sync bash -x ./tests/run_unit.sh unit/ --ignore=unit/models/generation/ --ignore=unit/models/policy/ --cov=nemo_rl --cov-report=term-missing --cov-report=json --hf-gated
 

tests/unit/L0_Unit_Tests_Policy.sh

Lines changed: 2 additions & 0 deletions

@@ -15,6 +15,8 @@
 #!/bin/bash
 set -xeuo pipefail # Exit immediately if a command exits with a non-zero status
 
+uv run tests/unit/prepare_unit_test_assets.py
+
 cd /opt/nemo-rl
 uv run --no-sync bash -x ./tests/run_unit.sh unit/models/policy/ --cov=nemo_rl --cov-report=term-missing --cov-report=json --hf-gated
 

tests/unit/conftest.py

Lines changed: 46 additions & 0 deletions

@@ -576,3 +576,49 @@ def tiny_gemma3_model_path():
     tokenizer.save_pretrained(model_path)
     del model, tokenizer
     yield model_path
+
+
+def _build_tiny_nemotron5_h_checkpoint(model_path: str) -> None:
+    import shutil
+
+    from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+    config = AutoConfig.from_pretrained(
+        "nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True
+    )
+    config.hybrid_override_pattern = "M*-"
+    config.num_hidden_layers = 3
+    config.intermediate_size = 32
+    config.hidden_size = 256
+    config.num_attention_heads = 8
+    config.mamba_num_heads = 8
+    config.num_key_value_heads = 8
+    config.n_groups = 1
+
+    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(
+        "nvidia/Nemotron-H-8B-Base-8K", trust_remote_code=True
+    )
+
+    shutil.rmtree(model_path, ignore_errors=True)
+    model.save_pretrained(model_path)
+    tokenizer.save_pretrained(model_path)
+
+
+@pytest.fixture(scope="session")
+def tiny_nemotron5_h_model_path():
+    """Fixture that returns a path to a tiny nemotron model with a dummy tokenizer.
+
+    If the asset hasn't been prepared by the prepare script, skip the tests that require it.
+    """
+    model_path = os.path.join(
+        TEST_ASSETS_DIR, "tiny_nemotron5_h_with_nemotron_tokenizer"
+    )
+
+    config_file = os.path.join(model_path, "config.json")
+    if not os.path.exists(config_file):
+        pytest.skip(
+            "Tiny Nemotron-H test asset not prepared. Run `uv run tests/unit/prepare_unit_test_assets.py` first."
+        )
+
+    yield model_path
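As a hypothetical usage sketch (not part of this diff), a test consuming the fixture could look like the following; instantiating the full model may additionally require the automodel extra (mamba-ssm, causal-conv1d), so this example only inspects the saved config:

```python
# Hypothetical test using the fixture above; checks the shrunken config written
# by the prepare script without instantiating the model itself.
import os

from transformers import AutoConfig


def test_tiny_nemotron5_h_asset(tiny_nemotron5_h_model_path):
    config = AutoConfig.from_pretrained(
        tiny_nemotron5_h_model_path, trust_remote_code=True
    )
    assert config.num_hidden_layers == 3
    assert os.path.exists(os.path.join(tiny_nemotron5_h_model_path, "config.json"))
```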

tests/unit/models/policy/test_dtensor_worker.py

Lines changed: 7 additions & 0 deletions

@@ -328,6 +328,13 @@ def training_setup(request, two_gpu_virtual_cluster):
         ("tiny_gemma3_model_path", 1, 1, False, True, True),
         ("tiny_gemma3_model_path", 1, 1, True, True, True),
         # CP doesn't support gemma3 due to spda input has attent_mask != None.
+        # Nemotron-H doesn't support SP https://github.com/NVIDIA-NeMo/RL/issues/881
+        # ("tiny_nemotron5_h_model_path", 1, 1, True, True, False),
+        # ("tiny_nemotron5_h_model_path", 1, 1, True, False, True),
+        # ("tiny_nemotron5_h_model_path", 1, 1, True, True, True),
+        ("tiny_nemotron5_h_model_path", 1, 1, False, False, False),
+        ("tiny_nemotron5_h_model_path", 1, 1, False, True, True),
+        # nemotron5_h doesn't support cp
     ],
     indirect=True,
 )
