Skip to content

Commit 1438331

Browse files
authored
Remove deprecated FindTiedParametersResult (#3786)
Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
1 parent a737437 commit 1438331

File tree

1 file changed

+2
-19
lines changed

1 file changed

+2
-19
lines changed

src/accelerate/utils/modeling.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -478,23 +478,6 @@ def get_non_persistent_buffers(module: nn.Module, recurse: bool = False, fqns: b
478478
return non_persistent_buffers_set
479479

480480

481-
class FindTiedParametersResult(list):
482-
"""
483-
This is a subclass of a list to handle backward compatibility for Transformers. Do not rely on the fact this is not
484-
a list or on the `values` method as in the future this will be removed.
485-
"""
486-
487-
def __init__(self, *args, **kwargs):
488-
super().__init__(*args, **kwargs)
489-
490-
def values(self):
491-
warnings.warn(
492-
"The 'values' method of FindTiedParametersResult is deprecated and will be removed in Accelerate v1.3.0. ",
493-
FutureWarning,
494-
)
495-
return sum([x[1:] for x in self], [])
496-
497-
498481
def check_tied_parameters_in_config(model: nn.Module):
499482
"""
500483
Check if there is any indication in the given model that some weights should be tied.
@@ -568,7 +551,7 @@ def check_tied_parameters_on_same_device(tied_params, device_map):
568551
)
569552

570553

571-
def find_tied_parameters(model: torch.nn.Module, **kwargs):
554+
def find_tied_parameters(model: torch.nn.Module, **kwargs) -> list[list[str]]:
572555
"""
573556
Find the tied parameters in a given model.
574557
@@ -620,7 +603,7 @@ def find_tied_parameters(model: torch.nn.Module, **kwargs):
620603
tied_param_groups[param_name] = []
621604
tied_param_groups[param_name].append(tied_param_name)
622605

623-
return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()])
606+
return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]
624607

625608

626609
def retie_parameters(model, tied_params):

0 commit comments

Comments
 (0)