@@ -646,48 +646,49 @@ def forward(
     def _load_from_state_dict(
         self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
     ):
-        # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
-        for hook in self._load_state_dict_pre_hooks.values():
-            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
-
-        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
-        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
-        local_state = {k: v for k, v in local_name_params if v is not None}
-
-        key = "qkv_weight"
-        k1 = "q_proj.weight"
-        k2 = "k_proj.weight"
-        k3 = "v_proj.weight"
-        q_w = state_dict[prefix + k1]
-        k_w = state_dict[prefix + k2]
-        v_w = state_dict[prefix + k3]
-
-        device_mesh = self.helper_layout.device_mesh
-        sharding_spec = self.helper_layout.sharding_spec
-        q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
-        k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
-        v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
-
-        qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
-
-        input_param = nn.Parameter(
-            qkv_w
-        )  # NOTE qkv_weight doesn't have to be a distensor, like input_param = sharded_tensor_to_param(input_param)
-
-        param = local_state[key]
-
-        try:
-            with torch.no_grad():
-                param.copy_(input_param)
-        except Exception as ex:
-            error_msgs.append(
-                'While copying the parameter named "{}", '
-                "whose dimensions in the model are {} and "
-                "whose dimensions in the checkpoint are {}, "
-                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
-            )
+        if self.num_heads == self.num_key_value_heads:
+            # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
+            for hook in self._load_state_dict_pre_hooks.values():
+                hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+            persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+            local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+            local_state = {k: v for k, v in local_name_params if v is not None}
+
+            key = "qkv_weight"
+            k1 = "q_proj.weight"
+            k2 = "k_proj.weight"
+            k3 = "v_proj.weight"
+            q_w = state_dict[prefix + k1]
+            k_w = state_dict[prefix + k2]
+            v_w = state_dict[prefix + k3]
+
+            device_mesh = self.helper_layout.device_mesh
+            sharding_spec = self.helper_layout.sharding_spec
+            q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
+            k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
+            v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
+
+            qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
+
+            input_param = nn.Parameter(
+                qkv_w
+            )  # NOTE qkv_weight doesn't have to be a distensor, like input_param = sharded_tensor_to_param(input_param)
+
+            param = local_state[key]
+
+            try:
+                with torch.no_grad():
+                    param.copy_(input_param)
+            except Exception as ex:
+                error_msgs.append(
+                    'While copying the parameter named "{}", '
+                    "whose dimensions in the model are {} and "
+                    "whose dimensions in the checkpoint are {}, "
+                    "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+                )
 
-        strict = False  # to avoid unexpected_keys
+        strict = False  # to avoid unexpected_keys
         super()._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
         )
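
For context on the new `if self.num_heads == self.num_key_value_heads:` guard: the fused `qkv_weight` load only applies when the model is not using grouped-query attention, because `torch.stack` requires the q/k/v projection weights to share a shape, and under GQA `k_proj`/`v_proj` are smaller than `q_proj`. A minimal standalone sketch (not part of the patch; the sizes are illustrative) of that shape assumption:

```python
# Illustrative shapes only: shows why torch.stack([q_w.T, k_w.T, v_w.T]) requires
# num_heads == num_key_value_heads (plain multi-head attention, no GQA).
import torch

hidden_size, head_dim = 32, 8
num_heads = num_key_value_heads = 4

q_w = torch.randn(num_heads * head_dim, hidden_size)            # q_proj.weight
k_w = torch.randn(num_key_value_heads * head_dim, hidden_size)  # k_proj.weight
v_w = torch.randn(num_key_value_heads * head_dim, hidden_size)  # v_proj.weight

# Same stacking as the patch: all three transposed weights must have identical shape.
qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
print(qkv_w.shape)  # torch.Size([3, 32, 32])

# With GQA (e.g. num_key_value_heads = 2), k_w.T and v_w.T would be (32, 16) while
# q_w.T stays (32, 32), and the stack above would raise a RuntimeError.
```

With the guard in place, GQA checkpoints skip the fused copy and fall through to `super()._load_from_state_dict(...)`, with `strict = False` still set to tolerate the key mismatch.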