File tree Expand file tree Collapse file tree 2 files changed +4
-4
lines changed
nemo_reinforcer/models/policy Expand file tree Collapse file tree 2 files changed +4
-4
lines changed Original file line number Diff line number Diff line change 2525 FSDPModule ,
2626)
2727from transformers import AutoModelForCausalLM , AutoTokenizer
28- from transformers .modeling_utils import _get_tied_weight_keys
28+ from transformers .integrations . accelerate import find_tied_parameters
2929from nemo_reinforcer .models .dtensor .parallelize import _parallelize_model
3030
3131from nemo_reinforcer .algorithms .interfaces import LossFunction
@@ -256,7 +256,7 @@ def train(
256256 mbs : Optional [int ] = None ,
257257 ) -> Dict [str , Any ]:
258258 """Train the policy on a batch of data with a given loss function."""
259- num_tied_weights = len (_get_tied_weight_keys (self .model ))
259+ num_tied_weights = len (find_tied_parameters (self .model ))
260260 skip_tie_check = os .environ .get ("NRL_SKIP_TIED_WEIGHT_CHECK" )
261261 if (
262262 num_tied_weights != 0
Original file line number Diff line number Diff line change 3939)
4040
4141from transformers import AutoModelForCausalLM , AutoTokenizer
42- from transformers .modeling_utils import _get_tied_weight_keys
42+ from transformers .integrations . accelerate import find_tied_parameters
4343from nemo_reinforcer .models .policy import PolicyConfig
4444from nemo_reinforcer .models .policy .utils import import_class_from_path
4545from nemo_reinforcer .distributed .virtual_cluster import (
@@ -229,7 +229,7 @@ def train(
229229 ) -> Dict [str , Any ]:
230230 """Train the policy on a batch of data with a given loss function."""
231231 # Check if the model has tied weights
232- num_tied_weights = len (_get_tied_weight_keys (self .model ))
232+ num_tied_weights = len (find_tied_parameters (self .model ))
233233 skip_tie_check = os .environ .get ("NRL_SKIP_TIED_WEIGHT_CHECK" )
234234 if num_tied_weights != 0 and not skip_tie_check :
235235 raise ValueError (
You can’t perform that action at this time.
0 commit comments