@@ -1906,6 +1906,7 @@ def unload_lora_weights(self, reset_to_overwritten_params=False):
 
             for name, module in transformer.named_modules():
                 if isinstance(module, torch.nn.Linear) and name in module_names:
+                    module_weight = module.weight.data
                     module_bias = module.bias.data if module.bias is not None else None
                     bias = module_bias is not None
 
@@ -1919,6 +1920,7 @@ def unload_lora_weights(self, reset_to_overwritten_params=False):
                             in_features,
                             out_features,
                             bias=bias,
+                            dtype=module_weight.dtype,
                         )
 
                     tmp_state_dict = {"weight": current_param_weight}
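For context on why the added dtype argument matters: a torch.nn.Linear constructed without an explicit dtype uses torch.get_default_dtype() (usually float32), so a module rebuilt to receive lower-precision weights can end up with a dtype that no longer matches the rest of the transformer. A minimal standalone sketch of that pitfall, not the exact diffusers code path; the 128/3072 sizes and bfloat16 dtype are illustrative assumptions:

import torch

# Hypothetical original projection whose weights are stored in bfloat16.
original = torch.nn.Linear(128, 3072, bias=True, dtype=torch.bfloat16)

# Rebuilding without a dtype falls back to the default (float32); loading the
# bfloat16 state dict then silently casts the weights to float32.
rebuilt = torch.nn.Linear(128, 3072, bias=True)
rebuilt.load_state_dict({"weight": original.weight.data, "bias": original.bias.data})
print(rebuilt.weight.dtype)  # torch.float32

# Threading the captured dtype through, as the patch above does, keeps it consistent.
rebuilt = torch.nn.Linear(128, 3072, bias=True, dtype=original.weight.dtype)
rebuilt.load_state_dict({"weight": original.weight.data, "bias": original.bias.data})
print(rebuilt.weight.dtype)  # torch.bfloat16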
@@ -2021,12 +2023,16 @@ def _maybe_expand_transformer_param_shape_or_error_(
                 parent_module = transformer.get_submodule(parent_module_name)
 
                 with torch.device("meta"):
-                    expanded_module = torch.nn.Linear(in_features, out_features, bias=bias)
+                    expanded_module = torch.nn.Linear(
+                        in_features, out_features, bias=bias, dtype=module_weight.dtype
+                    )
                 # Only weights are expanded and biases are not. This is because only the input dimensions
                 # are changed while the output dimensions remain the same. The shape of the weight tensor
                 # is (out_features, in_features), while the shape of bias tensor is (out_features,), which
                 # explains the reason why only weights are expanded.
-                new_weight = torch.zeros_like(expanded_module.weight.data)
+                new_weight = torch.zeros_like(
+                    expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
+                )
                 slices = tuple(slice(0, dim) for dim in module_weight_shape)
                 new_weight[slices] = module_weight
                 tmp_state_dict = {"weight": new_weight}
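A similarly hedged sketch of the expansion path above: the expanded module is built on the meta device, so torch.zeros_like must be told explicitly which real device and dtype to materialize on; the original (smaller) weight is then copied into the leading slice. The 64 → 128 input expansion and bfloat16 dtype are illustrative assumptions, not values from this diff:

import torch

# Hypothetical weight whose in_features must grow from 64 to 128 while keeping
# its dtype and device.
module_weight = torch.randn(3072, 64, dtype=torch.bfloat16)

with torch.device("meta"):
    expanded_module = torch.nn.Linear(128, 3072, bias=True, dtype=module_weight.dtype)

# zeros_like on a meta tensor would itself live on the meta device, so the real
# device and dtype are passed explicitly, mirroring the patched call.
new_weight = torch.zeros_like(
    expanded_module.weight.data, device=module_weight.device, dtype=module_weight.dtype
)
slices = tuple(slice(0, dim) for dim in module_weight.shape)
new_weight[slices] = module_weight

print(new_weight.shape, new_weight.dtype)  # torch.Size([3072, 128]) torch.bfloat16

The surrounding code then packs this tensor into tmp_state_dict = {"weight": new_weight} for loading into the meta-constructed module, which is why the materialized tensor must already carry the correct device and dtype.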