Commit 669e70c

adil-a and hemildesai authored
feat: DTensorPolicyV2 GPT-OSS SFT support (#1470)
Signed-off-by: Hemil Desai <[email protected]>
Co-authored-by: Hemil Desai <[email protected]>
1 parent 56e8fcb commit 669e70c

22 files changed: +3282 −1368 lines changed
Submodule Automodel updated 477 files
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+defaults: ../../sft.yaml
+cluster:
+  gpus_per_node: 8
+policy:
+  model_name: openai/gpt-oss-20b
+  train_global_batch_size: 128
+  train_micro_batch_size: 8
+  max_total_sequence_length: 512
+  dequantize_base_checkpoint: true
+  dtensor_cfg:
+    expert_parallel_size: 8
+    automodel_kwargs:
+      backend:
+        _target_: nemo_automodel.components.moe.utils.BackendConfig
+        attn: flex
+        linear: te
+        rms_norm: te
+        enable_deepep: true
+        fake_balanced_gate: false
+        enable_hf_state_dict_adapter: true
+checkpointing:
+  checkpoint_dir: results/sft-gpt-oss-20b-1n8g-fsdp8ep8-automodel
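
The recipe stays at 22 lines because `defaults: ../../sft.yaml` pulls in the base SFT recipe and the keys above override it. As a minimal sketch of how that kind of relative-defaults merge can be resolved with OmegaConf (the helper `load_with_defaults` is hypothetical, not NeMo RL's actual loader):

from pathlib import Path

from omegaconf import OmegaConf


def load_with_defaults(config_path: str):
    # Load the child config, then recursively merge it on top of its
    # `defaults:` base so that the child's keys win.
    cfg = OmegaConf.load(config_path)
    base_rel = cfg.pop("defaults", None)
    if base_rel is not None:
        base = load_with_defaults(str(Path(config_path).parent / base_rel))
        cfg = OmegaConf.merge(base, cfg)  # child overrides base
    return cfg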

nemo_rl/models/policy/__init__.py

Lines changed: 29 additions & 0 deletions
@@ -34,6 +34,34 @@ class LoRAConfig(TypedDict):
     use_triton: NotRequired[bool]
 
 
+class AutomodelBackendConfig(TypedDict):
+    # Hydra target class path (e.g., "nemo_automodel.components.moe.utils.BackendConfig")
+    _target_: str
+    # Attention implementation: "te" (Transformer Engine), "flex" (FlexAttention), etc.
+    attn: NotRequired[str]
+    # Linear layer implementation: "te" (Transformer Engine), etc.
+    linear: NotRequired[str]
+    # RMSNorm implementation: "te" (Transformer Engine), etc.
+    rms_norm: NotRequired[str]
+    # Enable DeepEP (Deep Expert Parallelism) for MoE models
+    enable_deepep: NotRequired[bool]
+    # Use fake balanced gate for testing/debugging MoE
+    fake_balanced_gate: NotRequired[bool]
+    # Enable HuggingFace state dict adapter for checkpoint loading
+    enable_hf_state_dict_adapter: NotRequired[bool]
+    # Enable FSDP-specific optimizations
+    enable_fsdp_optimizations: NotRequired[bool]
+    # Precision for the MoE gate computation (e.g., "float64", "float32")
+    gate_precision: NotRequired[str]
+
+
+class AutomodelKwargs(TypedDict):
+    # Whether to use Liger kernel optimizations (default: false)
+    use_liger_kernel: NotRequired[bool]
+    # Backend configuration for MoE models
+    backend: NotRequired[AutomodelBackendConfig]
+
+
 class DTensorConfigDisabled(TypedDict):
     enabled: Literal[False]

@@ -50,6 +78,7 @@ class DTensorConfig(TypedDict):
     custom_parallel_plan: str | None
     clear_cache_every_n_steps: NotRequired[int | None]
     lora_cfg: NotRequired[LoRAConfig | LoRAConfigDisabled]
+    automodel_kwargs: NotRequired[AutomodelKwargs]
 
 
 class SequencePackingConfigDisabled(TypedDict):
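
The `_target_` key follows the Hydra convention: it names the class to instantiate, and the sibling keys become constructor arguments. A minimal sketch of that mechanism, assuming plain importlib in place of `hydra.utils.instantiate`:

import importlib


def instantiate_target(cfg: dict):
    # "nemo_automodel.components.moe.utils.BackendConfig" -> module path + class name
    module_path, _, class_name = cfg["_target_"].rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    # Remaining keys (attn, linear, rms_norm, ...) become keyword arguments.
    kwargs = {k: v for k, v in cfg.items() if k != "_target_"}
    return cls(**kwargs)

Called on the `backend:` mapping from the YAML recipe above, this would return a BackendConfig built with attn="flex", linear="te", and so on.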

nemo_rl/models/policy/lm_policy.py

Lines changed: 6 additions & 0 deletions
@@ -111,6 +111,12 @@ def __init__(
         use_v2 = config.get("dtensor_cfg", {}).get("_v2", False)
         if use_v2:
             worker_builder_cls = "nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2"
+
+            if "TORCH_CUDA_ARCH_LIST" not in os.environ:
+                warnings.warn(
+                    "TORCH_CUDA_ARCH_LIST is not set. This is needed if using DeepEP in DTensorPolicyWorker V2. This variable is set in our container, but "
+                    "if you are running a custom container or baremetal, you may need to set this variable manually. Example: export TORCH_CUDA_ARCH_LIST='9.0 10.0'"
+                )
         else:
             assert (
                 config["dtensor_cfg"].get("lora_cfg", {}).get("enabled", False)
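
TORCH_CUDA_ARCH_LIST tells PyTorch's extension builder which compute capabilities to compile for, which matters when DeepEP kernels are built at runtime. If exporting it by hand is inconvenient, one possible workaround (a sketch, not part of this commit) is to derive it from the local GPU before the worker starts:

import os

import torch

if "TORCH_CUDA_ARCH_LIST" not in os.environ and torch.cuda.is_available():
    # e.g. (9, 0) on H100 -> "9.0"; set the list manually on heterogeneous clusters.
    major, minor = torch.cuda.get_device_capability()
    os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}"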

nemo_rl/models/policy/utils.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@
 
 # Try to import nemo_automodel classes, fallback to None if not available
 try:
-    from nemo_automodel.components._transformers.auto_model import (
+    from nemo_automodel._transformers.auto_model import (
         NeMoAutoModelForCausalLM,
         NeMoAutoModelForImageTextToText,
         NeMoAutoModelForTextToWaveform,
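
For context, the comment above describes an optional-import pattern whose except branch falls outside this hunk; it presumably looks roughly like this sketch:

try:
    from nemo_automodel._transformers.auto_model import (
        NeMoAutoModelForCausalLM,
        NeMoAutoModelForImageTextToText,
        NeMoAutoModelForTextToWaveform,
    )
except ImportError:
    # nemo_automodel is optional; callers check these names for None before use.
    NeMoAutoModelForCausalLM = None
    NeMoAutoModelForImageTextToText = None
    NeMoAutoModelForTextToWaveform = None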

nemo_rl/models/policy/workers/__init__.py

Whitespace-only changes.
