
Commit 8954a0c

flybird11111 authored and ver217 committed
[LowLevelZero] low level zero support lora (#5153)
* low level zero support lora
* add checkpoint test
* fix
* test ci
* git  # This is a combination of 3 commits: Update low_level_zero_plugin.py; Update low_level_zero_plugin.py; fix
* fix naming
1 parent 14b0d4c commit 8954a0c

File tree

8 files changed: +264 −8 lines changed


colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 98 additions & 5 deletions
@@ -1,12 +1,15 @@
+import enum
 import logging
 import os
+import warnings
 from functools import partial
 from pathlib import Path
 from types import MethodType
 from typing import Callable, Dict, Iterator, List, Optional, Tuple

 import torch
 import torch.nn as nn
+from torch.nn import Parameter
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
@@ -42,6 +45,12 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
 SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]


+class OptimizerParamCheckState(enum.Enum):
+    ORIGIN_PARAM_FINDED = 0
+    ORIGIN_PARAM_NOT_FIND = -1
+    LORA_PARM_EXISTED = -2
+
+
 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
     def __init__(self, module: nn.Module, precision: str) -> None:
         super().__init__(module)
@@ -209,6 +218,19 @@ def load_sharded_model(
         super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
         model.update_master_params()

+    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+        from peft import PeftModel
+
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        peft_model = model.unwrap()
+        assert isinstance(
+            peft_model, PeftModel
+        ), "The model doesn't have lora adapters, please enable lora before saving."
+        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+

 class LowLevelZeroPlugin(DPPluginBase):
     """
@@ -288,6 +310,7 @@ def __init__(
             cpu_offload=cpu_offload,
             master_weights=master_weights,
         )
+        self.lora_enabled = False
        self.verbose = verbose

         # set class name with stage, for better error message
@@ -311,6 +334,72 @@ def control_device(self) -> bool:
     def supported_devices(self) -> List[str]:
         return ["cuda", "npu"]

+    def support_lora(self) -> bool:
+        return True
+
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> nn.Module:
+        from peft import PeftModel, get_peft_model
+
+        assert not isinstance(model, LowLevelZeroModel), "Lora should be enabled before boosting the model."
+        self.lora_enabled = True
+        warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
+
+        if pretrained_dir is None:
+            peft_model = get_peft_model(model, lora_config)
+        else:
+            peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
+        return peft_model
+
+    def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter):
+        origin_param_id = id(origin_param)
+        for group_id, param_group in enumerate(optimizer.param_groups):
+            for p in param_group["params"]:
+                if id(p) == origin_param_id:
+                    return group_id
+        return -1
+
+    def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter):
+        origin_param_id = id(origin_param)
+        lora_param_id = id(lora_param)
+        target_group_id = None
+        for group_id, param_group in enumerate(optimizer.param_groups):
+            for p in param_group["params"]:
+                if id(p) == lora_param_id:
+                    # check if the lora parameter exists.
+                    return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED
+                if id(p) == origin_param_id:
+                    target_group_id = group_id
+        if target_group_id is not None:
+            return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED
+        else:
+            return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND
+
+    def add_lora_params_to_optimizer(self, model, optimizer):
+        """add lora parameters to optimizer"""
+        name2param = {}
+        for name, param in model.named_parameters():
+            name2param[name] = param
+
+        for name, param in name2param.items():
+            if "lora_A" in name or "lora_B" in name:
+                origin_key = name.replace("lora_A.", "")
+                origin_key = origin_key.replace("lora_B.", "")
+                origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer")
+                origin_param = name2param[origin_key]
+                group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
+                if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
+                    warnings.warn(
+                        "Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
+                    )
+                elif (
+                    check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
+                    and group_id is not None
+                    and group_id >= 0
+                ):
+                    optimizer.param_groups[group_id]["params"].append(param)
+
     def configure(
         self,
         model: nn.Module,
@@ -319,6 +408,15 @@ def configure(
         dataloader: Optional[DataLoader] = None,
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
+        if self.lora_enabled:
+            from peft import PeftModel
+
+            assert isinstance(
+                model, PeftModel
+            ), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
+            if optimizer is not None:
+                self.add_lora_params_to_optimizer(model, optimizer)
+
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(model, self.precision)

@@ -340,8 +438,3 @@ def get_checkpoint_io(self) -> CheckpointIO:
     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert isinstance(optimizer, LowLevelZeroOptimizer)
         return optimizer.no_sync()
-
-    def enable_lora(
-        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
-    ) -> nn.Module:
-        raise NotImplementedError
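Taken together, the hooks added to this file support a LoRA workflow along the following lines. This is a minimal sketch, assuming the Booster forwards `enable_lora` and the adapter save to the plugin and checkpoint IO methods added above; the placeholder `TinyModel`, the `target_modules` value, and the `booster.save_lora_as_pretrained` call are illustrative assumptions, not part of this diff (the test file below exercises real transformers models instead).

```python
# Sketch: enabling LoRA with LowLevelZeroPlugin (assumed wiring, run with torchrun).
import torch.nn as nn
from peft import LoraConfig
from torch.optim import Adam

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin


class TinyModel(nn.Module):
    """Placeholder model; any nn.Module / transformers model works the same way."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)

    def forward(self, x):
        return self.proj(x)


colossalai.launch_from_torch(config={})

plugin = LowLevelZeroPlugin(stage=2, precision="fp16")
booster = Booster(plugin=plugin)

model = TinyModel()
optimizer = Adam(model.parameters(), lr=1e-4)

# 1. Attach LoRA adapters *before* boosting; enable_lora asserts the model is not
#    yet wrapped in LowLevelZeroModel and sets plugin.lora_enabled = True.
lora_config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.1, target_modules=["proj"])
model = booster.enable_lora(model, lora_config=lora_config)

# 2. Boost: configure() sees lora_enabled and calls add_lora_params_to_optimizer,
#    appending each lora_A/lora_B parameter to the param_group of its base_layer.
model, optimizer, *_ = booster.boost(model, optimizer)

# ... training loop: loss = ...; booster.backward(loss, optimizer); optimizer.step() ...

# 3. Save only the adapter weights (delegates to PeftModel.save_pretrained); the
#    method name on the Booster is assumed to mirror the checkpoint IO hook above.
booster.save_lora_as_pretrained(model, "lora_checkpoint_dir")
```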

colossalai/pipeline/p2p.py

Lines changed: 12 additions & 0 deletions
@@ -45,6 +45,18 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
     return unpickle


+def check_for_nccl_backend(group):
+    pg = group or c10d._get_default_group()
+    # Gate PG wrapper check on Gloo availability.
+    if c10d._GLOO_AVAILABLE:
+        # It is not expected for PG to be wrapped many times, but support it just
+        # in case
+        while isinstance(pg, c10d._ProcessGroupWrapper):
+            pg = pg.wrapped_pg
+
+    return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL
+
+
 # NOTE: FIXME: NPU DOES NOT support isend nor irecv, so broadcast is kept for future use
 def _broadcast_object_list(
     object_list: List[Any], src: int, group: ProcessGroup, device: Optional[Union[torch.device, str, int]] = None
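`check_for_nccl_backend` unwraps any `_ProcessGroupWrapper` layers and reports whether the group's backend is NCCL. Since NCCL can only move CUDA tensors, such a check is typically used to decide where pickled objects are staged before a collective. A hedged sketch of that usage, relying on the helper defined above; the `_object_staging_device` name and the exact call sites inside `_broadcast_object_list` are assumptions, not shown in this hunk.

```python
# Sketch only: pick the staging device for serialized objects based on the backend.
import torch


def _object_staging_device(group=None) -> torch.device:
    # NCCL collectives require tensors on the current CUDA device; Gloo/MPI
    # groups can keep the pickled payload on the CPU.
    if check_for_nccl_backend(group):
        return torch.device("cuda", torch.cuda.current_device())
    return torch.device("cpu")
```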

colossalai/zero/low_level/bookkeeping/gradient_store.py

Lines changed: 3 additions & 0 deletions
@@ -82,6 +82,9 @@ def get_working_grads_by_group_id(self, group_id: int) -> List:
         """

         grad_list = []
+        # When using LoRa and the user sets multiple param_groups, it is possible that some param_groups have no parameters with gradients.
+        if group_id not in self._grads_of_params.keys():
+            return grad_list
         for param_grads in self._grads_of_params[group_id].values():
             grad_list.append(param_grads[self._working_index])
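The guard matters because, with LoRA enabled, a user may build the optimizer with several param_groups and only the groups holding trainable adapter parameters ever produce gradients. A small, self-contained illustration of that situation in plain PyTorch (not ColossalAI code):

```python
# Illustration only: an optimizer whose first param_group never receives
# gradients, which is the case the new group_id guard has to tolerate.
import torch
import torch.nn as nn
from torch.optim import Adam

base = nn.Linear(4, 4)
base.weight.requires_grad_(False)  # frozen base weights, as when LoRA is enabled
base.bias.requires_grad_(False)

lora_a = nn.Parameter(torch.zeros(4, 2))
lora_b = nn.Parameter(torch.randn(2, 4) * 0.01)

optimizer = Adam(
    [
        {"params": list(base.parameters()), "lr": 0.0},  # group 0: frozen, never has grads
        {"params": [lora_a, lora_b], "lr": 1e-4},        # group 1: trainable LoRA factors
    ]
)

x = torch.randn(1, 4)
loss = (base(x) + x @ lora_a @ lora_b).sum()
loss.backward()

# Only group 1's parameters end up with .grad set, so any bookkeeping keyed by
# group id only ever records group 1; looking up group 0 must return an empty
# list instead of raising a KeyError.
assert base.weight.grad is None and lora_a.grad is not None
```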

requirements/requirements-test.txt

Lines changed: 1 addition & 1 deletion
@@ -18,5 +18,5 @@ flash_attn
 datasets
 pydantic
 ray
-peft
+peft>=0.7.1
 #auto-gptq now not support torch1.12

requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -17,3 +17,4 @@ sentencepiece
 google
 protobuf
 transformers==4.36.2
+peft>=0.7.1

tests/test_booster/test_plugin/test_dp_plugin_base.py

Lines changed: 7 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import Callable, Iterator, List, Tuple, Union
+from typing import Callable, Dict, Iterator, List, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -51,6 +51,12 @@ def supported_precisions(self) -> List[str]:
     def no_sync(self, model: nn.Module) -> Iterator[None]:
         pass

+    def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
+        pass
+
+    def support_lora(self) -> bool:
+        pass
+

 def check_dataloader_sharding():
     plugin = DPPluginWrapper()

tests/test_booster/test_plugin/test_low_level_zero_plugin.py

Lines changed: 39 additions & 1 deletion
@@ -2,6 +2,7 @@

 import torch
 import torch.distributed as dist
+from peft import LoraConfig
 from torch.optim import Adam

 import colossalai
@@ -22,13 +23,17 @@


 @clear_cache_before_run()
-def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
+def run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config=None) -> Optional[str]:
     device = get_accelerator().get_current_device()
     try:
         plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=2**5)
         booster = Booster(plugin=plugin)
         model = model_fn()
         optimizer = Adam(model.parameters(), lr=1e-3)
+
+        if lora_config is not None:
+            model = booster.enable_lora(model, lora_config=lora_config)
+
         criterion = lambda x: x.mean()
         data = data_gen_fn()

@@ -48,6 +53,7 @@ def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:

     except Exception as e:
         return repr(e)
+        # raise e


 @parameterize("stage", [2])
@@ -91,10 +97,42 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
     assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])


+@parameterize("stage", [2])
+@parameterize("model_name", ["transformers_llama"])
+def check_low_level_zero_lora(stage, model_name, early_stop: bool = True):
+    passed_models = []
+    failed_info = {}  # (model_name, error) pair
+
+    sub_model_zoo = model_zoo.get_sub_registry(model_name)
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        task_type = None
+        if name == "transformers_llama_for_casual_lm":
+            task_type = "CAUSAL_LM"
+        if name == "transformers_llama_for_sequence_classification":
+            task_type = "SEQ_CLS"
+        lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
+        err = run_fn(stage, model_fn, data_gen_fn, output_transform_fn, lora_config)
+
+        torch.cuda.empty_cache()
+
+        if err is None:
+            passed_models.append(name)
+        else:
+            failed_info[name] = err
+            if early_stop:
+                break
+
+    if dist.get_rank() == 0:
+        print(f"Passed models({len(passed_models)}): {passed_models}\n\n")
+        print(f"Failed models({len(failed_info)}): {list(failed_info.keys())}\n\n")
+    assert len(failed_info) == 0, "\n".join([f"{k}: {v}" for k, v in failed_info.items()])
+
+
 def run_dist(rank, world_size, port, early_stop: bool = True):
     # init dist env
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
     check_low_level_zero_plugin(early_stop=early_stop)
+    check_low_level_zero_lora(early_stop=early_stop)


 @rerun_if_address_is_in_use()
