@@ -46,7 +46,6 @@ def count_parameters(self, log=print):
4646 raise NotImplementedError
4747
4848 def _load_param (self , name , module , param_name , param , buf_list , ckpt_t , offset ):
49-
5049 if name .split ("." )[- 1 ] in [
5150 "linear_keys" ,
5251 "linear_values" ,
@@ -73,7 +72,7 @@ def _load_param(self, name, module, param_name, param, buf_list, ckpt_t, offset)
7372 row_slice_start :row_slice_end ,
7473 ].size ()
7574 ), "An error in model's partition and checkpoint's slice was detected"
76- if param_name in buf_list :
75+ if name + "." + param_name in buf_list :
7776 module .register_buffer (
7877 param_name ,
7978 ckpt_t [
@@ -90,7 +89,7 @@ def _load_param(self, name, module, param_name, param, buf_list, ckpt_t, offset)
9089 assert (
9190 param .data .size () == ckpt_t [col_slice_start :col_slice_end ].size ()
9291 ), "An error in model's partition and checkpoint's slice was detected"
93- if param_name in buf_list :
92+ if name + "." + param_name in buf_list :
9493 module .register_buffer (
9594 param_name , ckpt_t [col_slice_start :col_slice_end ]
9695 )
@@ -120,9 +119,9 @@ def load_state_dict(
120119 if device == torch .device ("cpu" ):
121120 offset = 0
122121 buf_list = []
122+ for buf_name , buf in self .named_buffers ():
123+ buf_list .append (buf_name )
123124 for name , module in self .named_modules ():
124- for buf_name , buf in module .named_buffers ():
125- buf_list .append (buf_name )
126125 named_buf_and_param = list (module .named_buffers ()) + list (
127126 module .named_parameters ()
128127 )
@@ -205,9 +204,9 @@ def load_safe_state_dict(
205204 if device == torch .device ("cpu" ):
206205 offset = 0
207206 buf_list = []
207+ for buf_name , buf in self .named_buffers ():
208+ buf_list .append (buf_name )
208209 for name , module in self .named_modules ():
209- for buf_name , buf in module .named_buffers ():
210- buf_list .append (buf_name )
211210 named_buf_and_param = list (module .named_buffers ()) + list (
212211 module .named_parameters ()
213212 )
0 commit comments