@@ -2318,7 +2318,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
23182318
23192319                lora_A_weight_name  =  f"{ name }  
23202320                lora_B_weight_name  =  f"{ name }  
2321-                 lora_B_bias_name  =  f"{ name }  
2321+                 #  lora_B_bias_name = f"{name}.lora_B.bias"
23222322
23232323                if  lora_A_weight_name  not  in state_dict .keys ():
23242324                    continue 
@@ -2352,24 +2352,15 @@ def _maybe_expand_transformer_param_shape_or_error_(
23522352                expanded_module  =  torch .nn .Linear (
23532353                    in_features , out_features , bias = bias , device = module_weight .device , dtype = module_weight .dtype 
23542354                )
2355- 
2355+                  # Only weights are expanded and biases are not. 
23562356                new_weight  =  torch .zeros_like (
23572357                    expanded_module .weight .data , device = module_weight .device , dtype = module_weight .dtype 
23582358                )
23592359                slices  =  tuple (slice (0 , dim ) for  dim  in  module_weight .shape )
23602360                new_weight [slices ] =  module_weight 
23612361                expanded_module .weight .data .copy_ (new_weight )
2362- 
2363-                 bias_present_for_lora_B  =  lora_B_bias_name  in  state_dict 
2364-                 if  bias_present_for_lora_B :
2365-                     new_bias_shape  =  state_dict [lora_B_bias_name ].shape 
2366-                     if  bias  and  module_bias .shape  <  new_bias_shape :
2367-                         new_bias  =  torch .zeros_like (
2368-                             expanded_module .bias .data , device = module_bias .device , dtype = module_bias .dtype 
2369-                         )
2370-                         slices  =  tuple (slice (0 , dim ) for  dim  in  module_bias .shape )
2371-                         new_bias [slices ] =  module_bias 
2372-                         expanded_module .bias .data .copy_ (new_bias )
2362+                 if  module_bias  is  not None :
2363+                     expanded_module .bias .data .copy_ (module_bias )
23732364
23742365                setattr (parent_module , current_module_name , expanded_module )
23752366
0 commit comments