@@ -478,23 +478,6 @@ def get_non_persistent_buffers(module: nn.Module, recurse: bool = False, fqns: b
478478 return non_persistent_buffers_set
479479
480480
481- class FindTiedParametersResult(list):
482-     """
483-     This is a subclass of a list to handle backward compatibility for Transformers. Do not rely on the fact this is not
484-     a list or on the `values` method as in the future this will be removed.
485-     """
486-
487-     def __init__(self, *args, **kwargs):
488-         super().__init__(*args, **kwargs)
489-
490-     def values(self):
491-         warnings.warn(
492-             "The 'values' method of FindTiedParametersResult is deprecated and will be removed in Accelerate v1.3.0. ",
493-             FutureWarning,
494-         )
495-         return sum([x[1:] for x in self], [])
496-
497-
498481def check_tied_parameters_in_config(model: nn.Module):
499482 """
500483 Check if there is any indication in the given model that some weights should be tied.
@@ -568,7 +551,7 @@ def check_tied_parameters_on_same_device(tied_params, device_map):
568551 )
569552
570553
571- def find_tied_parameters(model: torch.nn.Module, **kwargs):
554+ def find_tied_parameters(model: torch.nn.Module, **kwargs) -> list[list[str]]:
572555 """
573556 Find the tied parameters in a given model.
574557
@@ -620,7 +603,7 @@ def find_tied_parameters(model: torch.nn.Module, **kwargs):
620603 tied_param_groups[param_name] = []
621604 tied_param_groups[param_name].append(tied_param_name)
622605
623- return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()])
606+ return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]
624607
625608
626609def retie_parameters(model, tied_params):
0 commit comments