Skip to content

Commit 1c7cbd9

Browse files
authored
fix: use find_tied_parameters api from HF for tied weight keys (#250)
Signed-off-by: Parth Chadha <pchadha@nvidia.com>
1 parent 1788e4c commit 1c7cbd9

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

nemo_reinforcer/models/policy/dtensor_policy_worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
FSDPModule,
2626
)
2727
from transformers import AutoModelForCausalLM, AutoTokenizer
28-
from transformers.modeling_utils import _get_tied_weight_keys
28+
from transformers.integrations.accelerate import find_tied_parameters
2929
from nemo_reinforcer.models.dtensor.parallelize import _parallelize_model
3030

3131
from nemo_reinforcer.algorithms.interfaces import LossFunction
@@ -256,7 +256,7 @@ def train(
256256
mbs: Optional[int] = None,
257257
) -> Dict[str, Any]:
258258
"""Train the policy on a batch of data with a given loss function."""
259-
num_tied_weights = len(_get_tied_weight_keys(self.model))
259+
num_tied_weights = len(find_tied_parameters(self.model))
260260
skip_tie_check = os.environ.get("NRL_SKIP_TIED_WEIGHT_CHECK")
261261
if (
262262
num_tied_weights != 0

nemo_reinforcer/models/policy/fsdp1_policy_worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
)
4040

4141
from transformers import AutoModelForCausalLM, AutoTokenizer
42-
from transformers.modeling_utils import _get_tied_weight_keys
42+
from transformers.integrations.accelerate import find_tied_parameters
4343
from nemo_reinforcer.models.policy import PolicyConfig
4444
from nemo_reinforcer.models.policy.utils import import_class_from_path
4545
from nemo_reinforcer.distributed.virtual_cluster import (
@@ -229,7 +229,7 @@ def train(
229229
) -> Dict[str, Any]:
230230
"""Train the policy on a batch of data with a given loss function."""
231231
# Check if the model has tied weights
232-
num_tied_weights = len(_get_tied_weight_keys(self.model))
232+
num_tied_weights = len(find_tied_parameters(self.model))
233233
skip_tie_check = os.environ.get("NRL_SKIP_TIED_WEIGHT_CHECK")
234234
if num_tied_weights != 0 and not skip_tie_check:
235235
raise ValueError(

0 commit comments

Comments (0)