@@ -255,7 +255,7 @@ def should_merge_last_two_dim(self) -> bool:
        """check whether to merge the last two dims"""
        return self.action == "merge_last_two_dim"

-    def run(self, tensor: ndarray) -> ndarray:
+    def run(self, state_dict: dict[str, ndarray], name: str) -> ndarray:
        """run some custom operation on ndarray, eg: transpose, merge_last_two_dim

        Args:
@@ -264,12 +264,21 @@ def run(self, tensor: ndarray) -> ndarray:
        Returns:
            ndarray: the final tensor
        """
+        tensor = state_dict.pop(name)
        if self.action == "transpose":
            return transpose(tensor, [1, 0])
        if self.action == "merge_last_two_dim":
            shape = tensor.shape
            assert len(shape) == 3
            return np.reshape(tensor, [shape[0], -1])
+        if self.action == "split":
+            assert self.index is not None, "when action is `split`, the `index` field is required."
+
+            if self.index < 2:
+                state_dict[name] = tensor  # keep the fused tensor for the next q/k/v mapping
+            # q/k/v are stored in the same tensor, so it is split into 3 arrays
+            tensors = np.split(tensor, 3, axis=-1)
+            return tensors[self.index]
        return tensor

    def matched(self, text: str) -> bool:
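
To make the `split` flow concrete, here is a minimal, self-contained sketch of the behavior the new branch implements (the `SplitMapping` class and the key name are illustrative stand-ins, not the real mapping class): the fused qkv tensor is popped, re-inserted while `index < 2`, and only consumed by the final split.

```python
import numpy as np
from numpy import ndarray


class SplitMapping:
    """Illustrative stand-in for a name mapping with action == "split"."""

    def __init__(self, index: int):
        self.index = index

    def run(self, state_dict: dict[str, ndarray], name: str) -> ndarray:
        tensor = state_dict.pop(name)
        if self.index < 2:
            # re-insert the fused tensor so the next q/k/v mapping can pop it again
            state_dict[name] = tensor
        return np.split(tensor, 3, axis=-1)[self.index]


state_dict = {"attn.qkv.weight": np.zeros([8, 24])}
q = SplitMapping(0).run(state_dict, "attn.qkv.weight")  # shape (8, 8), key kept
k = SplitMapping(1).run(state_dict, "attn.qkv.weight")  # shape (8, 8), key kept
v = SplitMapping(2).run(state_dict, "attn.qkv.weight")  # shape (8, 8), key consumed
assert "attn.qkv.weight" not in state_dict
```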
@@ -490,6 +499,9 @@ class LogitComparer:
    config_fields_to_be_removed: List[str] = ["transformers_version"]
    architectures: Dict[str, Type[PretrainedModel]] = {}

+    def __init__(self, input_dir: str) -> None:
+        self.input_dir = input_dir
+
    def get_paddle_pytorch_model_classes(self) -> Tuple[object, object]:
        """return the [PaddleModelClass, PytorchModelClass] to
        1. generate paddle model automatically
@@ -574,13 +586,15 @@ def compare_model_state_dicts(
        for name_mapping in name_mappings:
            model_state_saver.add(name_mapping.target_name, "pytorch_key", name_mapping.source_name)

-            paddle_numpy = paddle_state_dict.pop(name_mapping.target_name)
-            model_state_saver.add(name_mapping.target_name, "paddle", paddle_numpy)
-            model_state_saver.add(name_mapping.target_name, "paddle-shape", str(paddle_numpy.shape))
+            if name_mapping.target_name in paddle_state_dict:
+                paddle_numpy = paddle_state_dict.pop(name_mapping.target_name)
+                model_state_saver.add(name_mapping.target_name, "paddle", paddle_numpy)
+                model_state_saver.add(name_mapping.target_name, "paddle-shape", str(paddle_numpy.shape))

-            pytorch_numpy = pytorch_state_dict.pop(name_mapping.source_name)
-            model_state_saver.add(name_mapping.target_name, "pytorch", pytorch_numpy)
-            model_state_saver.add(name_mapping.target_name, "pytorch-shape", str(pytorch_numpy.shape))
+            if name_mapping.source_name in pytorch_state_dict:
+                pytorch_numpy = pytorch_state_dict.pop(name_mapping.source_name)
+                model_state_saver.add(name_mapping.target_name, "pytorch", pytorch_numpy)
+                model_state_saver.add(name_mapping.target_name, "pytorch-shape", str(pytorch_numpy.shape))

        model_state_saver.summary()
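
The membership guards matter because a split mapping can consume a source key before the comparison loop reaches it; an unconditional `pop` would then raise `KeyError`. A small sketch of the guarded pattern (dict contents and key names are illustrative):

```python
import numpy as np

# "encoder.k.weight" was already consumed by an earlier mapping; only "q" remains
paddle_state_dict = {"encoder.q.weight": np.ones([4, 4])}

for target_name in ["encoder.q.weight", "encoder.k.weight"]:
    if target_name in paddle_state_dict:
        paddle_numpy = paddle_state_dict.pop(target_name)
        print(target_name, paddle_numpy.shape)
    # absent keys are skipped instead of crashing the comparison
```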
@@ -594,8 +608,7 @@ def compare_logits(self) -> bool:
        paddle_model = PaddleModel.from_pretrained(self.input_dir)

        # 0. init the name_mapping & tensor_info_saver & logit_hooker
-        num_layers = self.get_num_layer(list(paddle_model.state_dict().keys()))
-        name_mappings = self.get_name_mapping(num_layers, paddle_model.config["architectures"])
+        name_mappings = self.get_name_mapping(paddle_model.config)
        tensor_info_saver = TensorInfoSaver()

        logit_hooker = LogitHooker(name_mappings, tensor_info_saver)
@@ -707,8 +720,9 @@ def convert(cls, weight_file: str, config: PretrainedConfig, cache_dir: str) ->
                logger.warning(f"key<{name_mapping.source_name}> not in the pytorch weight file.")
                continue

-            state_dict[name_mapping.target_name] = name_mapping.run(state_dict.pop(name_mapping.source_name))
-            all_layer_names.remove(name_mapping.source_name)
+            state_dict[name_mapping.target_name] = name_mapping.run(state_dict, name_mapping.source_name)
+            if name_mapping.source_name in all_layer_names:  # q/k/v mappings can share one source key
+                all_layer_names.remove(name_mapping.source_name)

        if all_layer_names:
            logger.warning(f"there are {len(all_layer_names)} tensors not initialized:")
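
Since `run` now receives the whole state dict, three q/k/v mappings can legitimately share one pytorch source key, but that key appears in `all_layer_names` only once, hence the new membership check. A hedged sketch with illustrative key names:

```python
all_layer_names = {"h.0.attn.c_attn.weight", "h.0.mlp.fc.weight"}

# three mappings (q, k, v) all point at the same fused source key
for source_name in ["h.0.attn.c_attn.weight"] * 3:
    if source_name in all_layer_names:
        all_layer_names.remove(source_name)  # only the first mapping removes it
    # later mappings pass through without raising KeyError

print(all_layer_names)  # {'h.0.mlp.fc.weight'} would be warned as not initialized
```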