+ import enum
import logging
import os
+ import warnings
from functools import partial
from pathlib import Path
from types import MethodType
from typing import Callable, Dict, Iterator, List, Optional, Tuple

import torch
import torch.nn as nn
+ from torch.nn import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
@@ -42,6 +45,12 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]


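+ # Result states of looking up a LoRA parameter's original (base) parameter in the
+ # optimizer's param_groups; see get_param_group_id() and add_lora_params_to_optimizer().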
+ class OptimizerParamCheckState(enum.Enum):
+     ORIGIN_PARAM_FINDED = 0
+     ORIGIN_PARAM_NOT_FIND = -1
+     LORA_PARM_EXISTED = -2
+
+
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
    def __init__(self, module: nn.Module, precision: str) -> None:
        super().__init__(module)
@@ -209,6 +218,19 @@ def load_sharded_model(
        super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
        model.update_master_params()

+     def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
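+         """Save the LoRA adapter of the boosted model to the ``checkpoint`` directory via peft's ``save_pretrained``."""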
+         if os.path.isfile(checkpoint):
+             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+             return
+         from peft import PeftModel
+
+         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+         peft_model = model.unwrap()
+         assert isinstance(
+             peft_model, PeftModel
+         ), "The model doesn't have LoRA adapters; please enable LoRA before saving."
+         return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+


class LowLevelZeroPlugin(DPPluginBase):
    """
@@ -288,6 +310,7 @@ def __init__(
            cpu_offload=cpu_offload,
            master_weights=master_weights,
        )
+         self.lora_enabled = False
        self.verbose = verbose

        # set class name with stage, for better error message
@@ -311,6 +334,72 @@ def control_device(self) -> bool:
    def supported_devices(self) -> List[str]:
        return ["cuda", "npu"]

+     def support_lora(self) -> bool:
+         return True
+
+     def enable_lora(
+         self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+     ) -> nn.Module:
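+         """Wrap ``model`` into a peft ``PeftModel``: create a new adapter from ``lora_config`` when
+         ``pretrained_dir`` is None, otherwise load trainable adapter weights from ``pretrained_dir``.
+         Must be called before ``boost``."""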
+         from peft import PeftModel, get_peft_model
+
+         assert not isinstance(model, LowLevelZeroModel), "LoRA should be enabled before boosting the model."
+         self.lora_enabled = True
+         warnings.warn("You have enabled LoRA training. Please check hyperparameters such as the learning rate (lr).")
+
+         if pretrained_dir is None:
+             peft_model = get_peft_model(model, lora_config)
+         else:
+             peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
+         return peft_model
+
+     def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter):
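+         """Locate the optimizer param_group that contains ``origin_param``.
+
+         Returns a ``(group_id, OptimizerParamCheckState)`` tuple: LORA_PARM_EXISTED if ``lora_param`` is
+         already registered, ORIGIN_PARAM_FINDED if ``origin_param`` was found, ORIGIN_PARAM_NOT_FIND otherwise;
+         ``group_id`` may be None when no matching group has been seen."""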
+         origin_param_id = id(origin_param)
+         lora_param_id = id(lora_param)
+         target_group_id = None
+         for group_id, param_group in enumerate(optimizer.param_groups):
+             for p in param_group["params"]:
+                 if id(p) == lora_param_id:
+                     # the lora parameter is already registered in the optimizer
+                     return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED
+                 if id(p) == origin_param_id:
+                     target_group_id = group_id
+         if target_group_id is not None:
+             return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED
+         else:
+             return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND
+
+     def add_lora_params_to_optimizer(self, model, optimizer):
+         """Add LoRA parameters to the optimizer's param_groups."""
+         name2param = {}
+         for name, param in model.named_parameters():
+             name2param[name] = param
+
+         for name, param in name2param.items():
+             if "lora_A" in name or "lora_B" in name:
+                 origin_key = name.replace("lora_A.", "")
+                 origin_key = origin_key.replace("lora_B.", "")
+                 origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer")
+                 origin_param = name2param[origin_key]
+                 group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
+                 if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
+                     warnings.warn(
+                         f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
+                     )
+                 elif (
+                     check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
+                     and group_id is not None
+                     and group_id >= 0
+                 ):
+                     optimizer.param_groups[group_id]["params"].append(param)
+
    def configure(
        self,
        model: nn.Module,
@@ -319,6 +408,15 @@ def configure(
        dataloader: Optional[DataLoader] = None,
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
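+         # When LoRA is enabled, the PEFT-injected lora_A/lora_B parameters may not be in the user's
+         # optimizer yet, so register them in its param_groups before boosting proceeds.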
+         if self.lora_enabled:
+             from peft import PeftModel
+
+             assert isinstance(
+                 model, PeftModel
+             ), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
+             if optimizer is not None:
+                 self.add_lora_params_to_optimizer(model, optimizer)
+
        if not isinstance(model, ModelWrapper):
            model = LowLevelZeroModel(model, self.precision)

@@ -340,8 +438,3 @@ def get_checkpoint_io(self) -> CheckpointIO:
    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
        assert isinstance(optimizer, LowLevelZeroOptimizer)
        return optimizer.no_sync()
-
-     def enable_lora(
-         self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
-     ) -> nn.Module:
-         raise NotImplementedError