File tree Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -82,17 +82,21 @@ def __init__(
8282
8383 if self .offload_to_disk_path :
8484 all_tensors = []
85+ param_names = []
8586 for module in self .modules :
8687 all_tensors .extend (list (module .parameters ()))
8788 all_tensors .extend (list (module .buffers ()))
89+
90+ param_names .extend ([param_name for param_name , _ in module .named_parameters ()])
91+
8892 all_tensors .extend (self .parameters )
8993 all_tensors .extend (self .buffers )
9094 all_tensors = list (dict .fromkeys (all_tensors )) # Remove duplicates
9195
9296 self .tensor_to_key = {tensor : f"tensor_{ i } " for i , tensor in enumerate (all_tensors )}
9397 self .key_to_tensor = {v : k for k , v in self .tensor_to_key .items ()}
9498
95- group_id_key = "_" .join (sorted ([ param_name for param_name , _ in module . named_parameters ()] ))
99+ group_id_key = "_" .join (sorted (param_names ))
96100 self ._disk_offload_group_id = hashlib .md5 (group_id_key .encode ()).hexdigest ()[:8 ]
97101
98102 self .cpu_param_dict = {}
You can’t perform that action at this time.
0 commit comments