import torch
import torch.nn as nn
+from peft import PeftModel
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
@@ -166,23 +167,29 @@ def load_sharded_optimizer(
        )

    def save_lora_as_pretrained(
-        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+        self,
+        model: Union[nn.Module, ModelWrapper],
+        checkpoint: str,
+        use_safetensors: bool = False,
+        state_dict: Optional[dict] = None,
    ) -> None:
        """
        Save the lora adapters and adapter configuration file to checkpoint directory.
        """
        from peft import PeftModel

        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        peft_model = model.unwrap(unwrap_peft=False)
+        assert isinstance(
+            peft_model, PeftModel
+        ), "The model doesn't have lora adapters, please enable lora before saving."
+        if state_dict is None:
+            state_dict = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, peft_model.state_dict())
        if self.coordinator.is_master():
-            peft_model = model.unwrap()
-            assert isinstance(
-                peft_model, PeftModel
-            ), "The model doesn't have lora adapters, please enable lora before saving."
            return peft_model.save_pretrained(
                checkpoint,
                safe_serialization=use_safetensors,
-                state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+                state_dict=state_dict,
            )

@@ -191,8 +198,11 @@ def __init__(self, module: nn.Module, *args, **kwargs) -> None:
        super().__init__(module)
        self.module = DDP(module, *args, **kwargs)

-    def unwrap(self):
-        return self.module.module
+    def unwrap(self, unwrap_peft: bool = True) -> nn.Module:
+        model = self.module.module
+        if unwrap_peft and isinstance(model, PeftModel):
+            model = model.get_base_model()
+        return model


class TorchDDPPlugin(DPPluginBase):
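
For reference, a minimal sketch of the code path these changes enable, written directly against peft rather than the ColossalAI wrapper and booster plumbing; TinyNet, the "fc" target module, and the "lora_ckpt" path are placeholder names for illustration, not part of this PR.

# Minimal sketch only: mirrors the PR's behaviour with plain peft objects.
import torch
import torch.nn as nn
from peft import LoraConfig, PeftModel, get_peft_model


class TinyNet(nn.Module):  # placeholder model, not from the PR
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

    def forward(self, x):
        return self.fc(x)


# Attach LoRA adapters to the toy model.
peft_model = get_peft_model(TinyNet(), LoraConfig(target_modules=["fc"]))

# unwrap(unwrap_peft=True) in the PR returns the base model via get_base_model();
# unwrap(unwrap_peft=False) keeps the PeftModel so its adapters can be saved.
base_model = peft_model.get_base_model()
assert isinstance(peft_model, PeftModel) and not isinstance(base_model, PeftModel)

# Default path of save_lora_as_pretrained: every rank builds a CPU copy of the
# state dict (tree_map in the PR), then only the master rank calls save_pretrained with it.
cpu_state_dict = {
    k: v.data.cpu() if torch.is_tensor(v) else v for k, v in peft_model.state_dict().items()
}
peft_model.save_pretrained("lora_ckpt", safe_serialization=True, state_dict=cpu_state_dict)

The new optional state_dict argument also lets a caller hand in an already-prepared state dict, in which case the per-rank CPU copy above is skipped.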