# See the License for the specific language governing permissions and
# limitations under the License.

-import math
import warnings
-from typing import List, Optional
+from typing import Optional

import torch
import torch.nn as nn
@@ -72,6 +71,9 @@ def __init__(self, base_layer: nn.Module, **kwargs):
        self._disable_adapters = False
        self.merged_adapters = []

+        # flag to enable/disable casting of input to weight dtype during forward call
+        self.cast_input_dtype_enabled = True
+
        base_layer = self.get_base_layer()
        if isinstance(base_layer, nn.Linear):
            in_features, out_features = base_layer.in_features, base_layer.out_features
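The `cast_input_dtype_enabled` flag added above controls whether the forward call casts its input to the weight dtype. A minimal sketch of that pattern, using a hypothetical `maybe_cast_input` helper rather than the actual PEFT internals:

```python
import torch
import torch.nn as nn


def maybe_cast_input(layer: nn.Module, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: cast the input to the weight dtype only when the flag is
    # enabled and the dtypes actually differ.
    if getattr(layer, "cast_input_dtype_enabled", True) and x.dtype != weight.dtype:
        return x.to(weight.dtype)
    return x
```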
@@ -118,7 +120,7 @@ def update_layer(
            requires_grad=True,
        )

-        self.scaling[adapter_name] = randlora_alpha / r / math.sqrt(self.num_bases)
+        self.scaling[adapter_name] = randlora_alpha / r

        # non trainable references to randlora_A/B buffers
        self.randlora_A = randlora_A
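With this change the scaling reduces to the LoRA-style ratio `randlora_alpha / r`, dropping the extra `1 / sqrt(num_bases)` factor. With illustrative values:

```python
# Illustrative values only.
randlora_alpha, r = 64, 32
scaling = randlora_alpha / r  # 2.0, no longer divided by sqrt(num_bases)
```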
@@ -153,6 +155,7 @@ def update_layer(
            )
            if randlora_A_param.shape[0] < self.r[adapter_name]:
                raise ValueError(error_tmpl.format("randlora_A", randlora_A_param.shape[0], self.r[adapter_name]))
+
            if randlora_B_param.shape[-1] < self.r[adapter_name]:
                raise ValueError(error_tmpl.format("randlora_B", randlora_B_param.shape[-1], self.r[adapter_name]))

@@ -169,9 +172,7 @@ def reset_randlora_parameters(self, adapter_name):
        if adapter_name in self.randlora_lambda.keys():
            with torch.no_grad():
                nn.init.zeros_(self.randlora_lambda[adapter_name])
-                nn.init.ones_(self.randlora_gamma[adapter_name]).fill_(
-                    1 / max(self.randlora_gamma[adapter_name].shape)
-                )
+                nn.init.constant_(self.randlora_gamma[adapter_name], 1 / max(self.randlora_gamma[adapter_name].shape))


class Linear(nn.Linear, RandLoraLayer):
@@ -198,7 +199,7 @@ def __init__(
        self.update_layer(adapter_name, randlora_A, randlora_B, r, randlora_alpha, randlora_dropout, init_weights)
        self.is_target_conv_1d_layer = is_target_conv_1d_layer

-    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
+    def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
        """
        Merge the active adapter weights into the base weights

@@ -207,7 +208,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
                If True, the merge operation will be performed in a copy of the original weights and check for NaNs
                before merging the weights. This is useful if you want to check if the merge operation will produce
                NaNs. Defaults to `False`.
-            adapter_names (`List[str]`, *optional*):
+            adapter_names (`list[str]`, *optional*):
                The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
                to `None`.
        """
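A usage sketch for the parameters documented above, assuming `layer` is a RandLora `Linear` that holds an adapter named "default" (the adapter name is illustrative):

```python
# `layer` is assumed to be a RandLora Linear with an adapter named "default".
layer.merge(safe_merge=True, adapter_names=["default"])  # copy, NaN-check, then merge
assert "default" in layer.merged_adapters

layer.unmerge()  # subtract the deltas again, restoring the base weights
```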
@@ -219,6 +220,8 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
        for active_adapter in adapter_names:
            if active_adapter in self.randlora_lambda.keys():
                base_layer = self.get_base_layer()
+                orig_dtype = base_layer.weight.dtype
+
                if safe_merge:
                    # Note that safe_merge will be slower than the normal merge
                    # because of the copy operation.
@@ -231,9 +234,11 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N
                            f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                        )

-                    base_layer.weight.data = orig_weights
+                    base_layer.weight.data = orig_weights.to(orig_dtype)
                else:
-                    base_layer.weight.data += self.get_delta_weight(active_adapter)
+                    delta_weight = self.get_delta_weight(active_adapter)
+                    base_layer.weight.data += delta_weight.to(orig_dtype)
+
                self.merged_adapters.append(active_adapter)

    def unmerge(self) -> None:
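The dtype handling added here keeps the base weight in its original dtype even when `get_delta_weight` returns a higher-precision tensor. A standalone sketch of the same safe-merge pattern in plain PyTorch, not the PEFT implementation:

```python
import torch


def safe_merge_sketch(weight: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
    # Merge on a copy, check for non-finite values, then cast back to the original dtype.
    orig_dtype = weight.dtype
    merged = weight.clone().to(delta.dtype) + delta
    if not torch.isfinite(merged).all():
        raise ValueError("NaNs detected in the merged weights.")
    return merged.to(orig_dtype)


# Example: an fp16 base weight merged with an fp32 delta stays fp16.
w = torch.randn(4, 4, dtype=torch.float16)
d = torch.randn(4, 4, dtype=torch.float32) * 0.01
assert safe_merge_sketch(w, d).dtype == torch.float16
```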
@@ -242,9 +247,12 @@ def unmerge(self) -> None:
            return

        while len(self.merged_adapters) > 0:
+            base_layer = self.get_base_layer()
+            orig_dtype = base_layer.weight.dtype
            active_adapter = self.merged_adapters.pop()
            if active_adapter in self.randlora_lambda.keys():
-                self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)
+                delta_weight = self.get_delta_weight(active_adapter)
+                base_layer.weight.data -= delta_weight.to(orig_dtype)

    def get_scaled_bases(self, adapter) -> tuple[torch.Tensor, torch.Tensor]:
        """
@@ -289,7 +297,7 @@ def get_scaled_bases(self, adapter) -> tuple[torch.Tensor, torch.Tensor]:
        update_B = sliced_B.flatten(start_dim=1)
        update_A = UniqueBaseGrad.apply(sliced_A, randlora_lambda, randlora_gamma).flatten(end_dim=1)

-        # Since update_A is applied on the smallest dimension, test whether update_A or update_B should applied first. This is done to reduce trainable parameters.
+        # Since update_A is applied on the smallest dimension, test whether update_A or update_B should be applied first. This is done to reduce trainable parameters.
        if min_dim == self.in_features:
            return update_A, update_B
        return update_B.T, update_A.T
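The swapped return path relies on the identity (A @ B).T == B.T @ A.T, so transposing both factors and reversing their order describes the same low-rank product from the other orientation. A quick numerical check with arbitrary shapes:

```python
import torch

# (A @ B).T == B.T @ A.T, which is why the swapped branch also transposes both factors.
A = torch.randn(4, 8)
B = torch.randn(8, 16)
assert torch.allclose((A @ B).T, B.T @ A.T, atol=1e-6)
```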