@@ -42,7 +42,6 @@
     AutoModelForSequenceClassification,
     AutoTokenizer,
 )
-from transformers.integrations.accelerate import find_tied_parameters
 from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
 
 from nemo_rl.algorithms.interfaces import LossFunction, LossType
@@ -56,7 +55,6 @@
     to_local_if_dtensor,
 )
 from nemo_rl.models.huggingface.common import (
-    ModelFlag,
     get_flash_attention_kwargs,
     pack_sequences,
 )
@@ -267,12 +265,8 @@ def __init__(
         self.model.config.pad_token_id = tokenizer.pad_token_id
 
         # caching since this property is not always preserved after FSDP
-        self.num_tied_weights = len(find_tied_parameters(self.model))
-        self.skip_tie_check = os.environ.get(
-            "NRL_SKIP_TIED_WEIGHT_CHECK"
-        ) or ModelFlag.SKIP_DTENSOR_TIED_WEIGHTS_CHECK.matches(model_name)
-
         self.tokenizer = tokenizer
+
         # ------------------------------------------------
         # 3) Move to GPU + Composable FSDP
         # (Initialize device mesh, shard submodules, then shard entire model)
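
For context on the lines removed in this hunk: they cached the number of tied parameter groups once at init time (before FSDP sharding could obscure the tying) together with an opt-out flag for the check in `train()`. Below is a minimal sketch, not part of this commit, of how that count is obtained, assuming `accelerate` is installed (the deleted import pulled the same helper through `transformers.integrations.accelerate`); the model name is only illustrative.

```python
# Sketch (not part of this commit): count tied parameter groups the way the
# removed __init__ code did. find_tied_parameters returns groups of parameter
# names that share one underlying tensor.
from accelerate.utils import find_tied_parameters
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # gpt2 ties its input/output embeddings
tied_groups = find_tied_parameters(model)
print(tied_groups)  # e.g. [['transformer.wte.weight', 'lm_head.weight']]
num_tied_weights = len(tied_groups)  # the value the worker previously cached
```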
@@ -528,15 +522,6 @@ def train(
         mbs: Optional[int] = None,
     ) -> dict[str, Any]:
         """Train the policy on a batch of data with a given loss function."""
-        # Check if the model has tied weights
-        if (
-            self.num_tied_weights != 0
-            and self.cfg["dtensor_cfg"]["tensor_parallel_size"] > 1
-            and not self.skip_tie_check
-        ):
-            raise ValueError(
-                f"Using dtensor policy with tp size {self.cfg['dtensor_cfg']['tensor_parallel_size']} for model ({self.cfg['model_name']}) that has tied weights (num_tied_weights={self.num_tied_weights}) is not supported (https://github.com/NVIDIA-NeMo/RL/issues/227). Please use dtensor policy with tensor parallel == 1 instead."
-            )
         if gbs is None:
             gbs = self.cfg["train_global_batch_size"]
         if mbs is None:
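
The guard deleted from `train()` above rejected tied-weight models whenever `dtensor_cfg.tensor_parallel_size > 1`, unless the `NRL_SKIP_TIED_WEIGHT_CHECK` environment variable or the per-model `ModelFlag` escape hatch was set. A hedged, standalone restatement of that condition follows; the function and argument names are illustrative only, and the `ModelFlag` path is omitted.

```python
import os

def check_tied_weight_support(num_tied_weights: int, tp_size: int, model_name: str) -> None:
    """Illustrative restatement of the removed guard, not the worker's API."""
    skip = bool(os.environ.get("NRL_SKIP_TIED_WEIGHT_CHECK"))
    if num_tied_weights != 0 and tp_size > 1 and not skip:
        raise ValueError(
            f"dtensor policy with tp size {tp_size} for model {model_name} with tied "
            f"weights (num_tied_weights={num_tied_weights}) is not supported "
            "(https://github.com/NVIDIA-NeMo/RL/issues/227); use tensor parallel == 1."
        )
```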
|