Commit fecf71e

fix: remove tie weight check (#700)
Signed-off-by: ruit <[email protected]>
1 parent d45ff3f commit fecf71e

12 files changed (+3, -94 lines)

docs/model-quirks.md

Lines changed: 0 additions & 8 deletions
@@ -4,14 +4,6 @@ This document outlines special cases and model-specific behaviors that require c
 
 ## Gemma-3
 
-### Tied Weights
-
-Weight tying between the embedding layer (`model.embed_tokens`) and output layer (`lm_head`) is currently not respected when using the DTensor policy when TP > 1 (See [this issue](https://github.com/NVIDIA-NeMo/RL/issues/227)). To avoid errors when training these models, we only allow training models with tied weights using the DTensor policy with TP=1. For Llama-3 and Qwen2.5 models, weight-tying is only enabled for the smaller models (< 2B), which can typically be trained without tensor parallelism. For Gemma-3, all model sizes have weight-tying enabled, including the larger models which require tensor parallelism. To support training of these models, we specially handle the Gemma-3 models by allowing training using the DTensor policy with TP > 1.
-
-**Special Handling:**
-- We skip the tied weights check for all Gemma-3 models when using the DTensor policy, allowing training using TP > 1.
-- We exclude `model.embed_tokens` and `lm_head` from the DTensor tensor parallel plan to maintain weight tying correctly.
-
 ### vLLM Initialization
 
 Gemma-3 models have a specific issue with vLLM dummy weight initialization due to a vLLM bug where [a `normalizer` buffer is created](https://github.com/vllm-project/vllm/blob/964472b9667508b1d4a7ed92068ff81740ae0036/vllm/model_executor/models/gemma3.py#L372) that is not present in the Hugging Face model. This causes the `normalizer` buffer to be set to dummy weights at initialization and then never updated with the correct values during model refit. As a workaround for this issue, we do not use dummy weight initialization for vLLM with Gemma-3 models and instead use the `load_format="auto"` setting to load the full weights at initialization.
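The retained vLLM workaround above amounts to choosing the weight load format per model family. A minimal sketch of that choice, assuming vLLM's standard `LLM` constructor and a hypothetical model name; this is illustrative, not the repository's actual initialization path:

# Sketch only: pick vLLM's load format per model, as described in the docs above.
# "dummy" skips real weight loading; "auto" loads the full checkpoint so buffers
# like Gemma-3's `normalizer` receive correct values.
from vllm import LLM

model_name = "google/gemma-3-1b-it"  # hypothetical example checkpoint

load_format = "auto" if "gemma-3" in model_name.lower() else "dummy"
llm = LLM(model=model_name, load_format=load_format)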

examples/configs/grpo_math_1B.yaml

Lines changed: 0 additions & 1 deletion
@@ -32,7 +32,6 @@ checkpointing:
   checkpoint_must_save_by: null
 
 policy:
-  # Qwen/Qwen2.5-1.5B has tied weights which are only supported with dtensor policy with tp size 1 (https://github.com/NVIDIA-NeMo/RL/issues/227)
   model_name: "Qwen/Qwen2.5-1.5B"
   tokenizer:
     name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default

examples/configs/recipes/llm/grpo-deepscaler-1.5b-8K.yaml

Lines changed: 0 additions & 1 deletion
@@ -32,7 +32,6 @@ checkpointing:
   checkpoint_must_save_by: null
 
 policy:
-  # Qwen/Qwen2.5-1.5B has tied weights which are only supported with dtensor policy with tp size 1 (https://github.com/NVIDIA-NeMo/RL/issues/227)
   model_name: "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
   tokenizer:
     name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default

nemo_rl/models/dtensor/parallelize.py

Lines changed: 2 additions & 9 deletions
@@ -342,19 +342,12 @@ def get_hf_tp_plan(model: PreTrainedModel):
     )
 
     # hf tp plan not contain embed_tokens, we add it and set to rowwise_rep
-    if (
-        f"{model_prefix}.embed_tokens" not in hf_tp_plan
-        and not model.config.tie_word_embeddings
-    ):
+    if f"{model_prefix}.embed_tokens" not in hf_tp_plan:
         hf_tp_plan[f"{model_prefix}.embed_tokens"] = "rowwise_rep"
 
     for k, v in hf_tp_plan.items():
         # speed up the tp plan for lm_head
-        if (
-            k == "lm_head"
-            and v == "colwise_rep"
-            and not model.config.tie_word_embeddings
-        ):
+        if k == "lm_head" and v == "colwise_rep":
             hf_tp_plan[k] = ColwiseParallel(
                 output_layouts=Shard(-1), use_local_output=False
             )
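With the conditions above simplified, the `embed_tokens` and `lm_head` entries are adjusted unconditionally, whether or not the model ties its word embeddings. A rough self-contained sketch of the two branches, assuming a recent PyTorch and a placeholder plan rather than a real model's plan:

# Sketch of the simplified get_hf_tp_plan branches; the starting plan is a placeholder.
from torch.distributed.tensor import Shard
from torch.distributed.tensor.parallel import ColwiseParallel

model_prefix = "model"
hf_tp_plan = {"lm_head": "colwise_rep"}  # hypothetical HF-provided plan

# An embed_tokens entry is now added whenever the HF plan omits it.
if f"{model_prefix}.embed_tokens" not in hf_tp_plan:
    hf_tp_plan[f"{model_prefix}.embed_tokens"] = "rowwise_rep"

# The string lm_head entry is now always replaced with an explicit ColwiseParallel style.
for k, v in hf_tp_plan.items():
    if k == "lm_head" and v == "colwise_rep":
        hf_tp_plan[k] = ColwiseParallel(output_layouts=Shard(-1), use_local_output=False)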

nemo_rl/models/huggingface/common.py

Lines changed: 0 additions & 6 deletions
@@ -39,22 +39,16 @@ class ModelFlag(Enum):
     configuration in different parts of the NeMo RL codebase.
 
     Flags:
-        SKIP_DTENSOR_TIED_WEIGHTS_CHECK: Models that should skip the tied weights check
-                                         for the DTensor Policy even without setting the
-                                         NRL_SKIP_TIED_WEIGHT_CHECK flag.
         VLLM_LOAD_FORMAT_AUTO: Models that should use the "auto" load format when initializing
                                VLLM.
 
     Each flag has a `matches` method that determines if the flag applies to a given model_name.
     """
 
-    SKIP_DTENSOR_TIED_WEIGHTS_CHECK = auto()
     VLLM_LOAD_FORMAT_AUTO = auto()
 
     def matches(self, model_name: str) -> bool:
         match self:
-            case ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK:
-                return is_gemma_model(model_name)
             case ModelFlag.VLLM_LOAD_FORMAT_AUTO:
                 return is_gemma_model(model_name)
             case _:
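The remaining flag is consumed through `matches`, as the docstring above describes. A short illustrative check; the model names are examples only, assuming `is_gemma_model` recognizes Gemma checkpoints by name:

from nemo_rl.models.huggingface.common import ModelFlag

# Example names only; the actual matching is delegated to is_gemma_model().
assert ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches("google/gemma-3-1b-it")
assert not ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches("Qwen/Qwen2.5-1.5B")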

nemo_rl/models/policy/dtensor_policy_worker.py

Lines changed: 1 addition & 16 deletions
@@ -42,7 +42,6 @@
     AutoModelForSequenceClassification,
     AutoTokenizer,
 )
-from transformers.integrations.accelerate import find_tied_parameters
 from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
 
 from nemo_rl.algorithms.interfaces import LossFunction, LossType
@@ -56,7 +55,6 @@
     to_local_if_dtensor,
 )
 from nemo_rl.models.huggingface.common import (
-    ModelFlag,
     get_flash_attention_kwargs,
     pack_sequences,
 )
@@ -267,12 +265,8 @@ def __init__(
         self.model.config.pad_token_id = tokenizer.pad_token_id
 
         # caching since this property is not always preserved after FSDP
-        self.num_tied_weights = len(find_tied_parameters(self.model))
-        self.skip_tie_check = os.environ.get(
-            "NRL_SKIP_TIED_WEIGHT_CHECK"
-        ) or ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK.matches(model_name)
-
         self.tokenizer = tokenizer
+
         # ------------------------------------------------
         # 3) Move to GPU + Composable FSDP
         # (Initialize device mesh, shard submodules, then shard entire model)
@@ -528,15 +522,6 @@ def train(
         mbs: Optional[int] = None,
     ) -> dict[str, Any]:
         """Train the policy on a batch of data with a given loss function."""
-        # Check if the model has tied weights
-        if (
-            self.num_tied_weights != 0
-            and self.cfg["dtensor_cfg"]["tensor_parallel_size"] > 1
-            and not self.skip_tie_check
-        ):
-            raise ValueError(
-                f"Using dtensor policy with tp size {self.cfg['dtensor_cfg']['tensor_parallel_size']} for model ({self.cfg['model_name']}) that has tied weights (num_tied_weights={self.num_tied_weights}) is not supported (https://github.com/NVIDIA-NeMo/RL/issues/227). Please use dtensor policy with tensor parallel == 1 instead."
-            )
         if gbs is None:
             gbs = self.cfg["train_global_batch_size"]
         if mbs is None:

tests/unit/models/generation/test_vllm_generation.py

Lines changed: 0 additions & 11 deletions
@@ -302,17 +302,6 @@ def test_input_data(tokenizer):
     )
 
 
-@pytest.fixture(scope="module", autouse=True)
-def skip_tied_weight_check_for_all():
-    """Automatically skip tied weight check for all tests in this module."""
-    os.environ["NRL_SKIP_TIED_WEIGHT_CHECK"] = "1"
-
-    yield
-
-    # Restore the original value
-    os.environ.pop("NRL_SKIP_TIED_WEIGHT_CHECK", None)
-
-
 def test_vllm_missing_required_config_key(cluster):
     """Test that an assertion error is raised when a required config key is missing."""
     # Create a config missing a required key by removing 'model_name'

tests/unit/models/generation/test_vllm_large_model.py

Lines changed: 0 additions & 9 deletions
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 from copy import deepcopy
 
 import pytest
@@ -63,14 +62,6 @@
 }
 
 
-@pytest.fixture(scope="module", autouse=True)
-def skip_tied_weight_check():
-    """Automatically skip tied weight check for all tests in this module."""
-    os.environ["NRL_SKIP_TIED_WEIGHT_CHECK"] = "1"
-    yield
-    os.environ.pop("NRL_SKIP_TIED_WEIGHT_CHECK", None)
-
-
 @pytest.fixture(scope="function")
 def two_node_cluster():
     """Create a virtual cluster with 2 nodes for testing large models."""

tests/unit/models/huggingface/test_common.py

Lines changed: 0 additions & 2 deletions
@@ -39,7 +39,6 @@
 )
 def test_gemma_models(model_name):
     assert is_gemma_model(model_name)
-    assert ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK.matches(model_name)
     assert ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(model_name)
 
 
@@ -54,5 +53,4 @@ def test_gemma_models(model_name):
 )
 def test_non_gemma_models(model_name):
     assert not is_gemma_model(model_name)
-    assert not ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK.matches(model_name)
     assert not ModelFlag.VLLM_LOAD_FORMAT_AUTO.matches(model_name)

tests/unit/models/policy/test_dtensor_worker.py

Lines changed: 0 additions & 12 deletions
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 import pprint
 
 import pytest
@@ -107,17 +106,6 @@ def create_test_config(
     }
 
 
-@pytest.fixture(scope="module", autouse=True)
-def skip_tied_weight_check_for_all():
-    """Automatically skip tied weight check for all tests in this module."""
-    os.environ["NRL_SKIP_TIED_WEIGHT_CHECK"] = "1"
-
-    yield
-
-    # Restore the original value
-    os.environ.pop("NRL_SKIP_TIED_WEIGHT_CHECK", None)
-
-
 @pytest.fixture(scope="module")
 def two_gpu_virtual_cluster():
     cluster_name = "test"
