Commit 14b0d4c

Baizhou Zhang authored and ver217 committed
[lora] add lora APIs for booster, support lora for TorchDDP (#4981)
* add apis and peft requirement
* add license and implement apis
* add checkpointio apis
* add torchddp fwd_bwd test
* add support_lora methods
* add checkpointio test and debug
* delete unneeded codes
* remove peft from LICENSE
* add concrete methods for enable_lora
* simplify enable_lora api
* fix requirements
1 parent c1594e4 commit 14b0d4c

File tree: 11 files changed (+265, -7 lines)


colossalai/booster/booster.py

Lines changed: 57 additions & 0 deletions
@@ -8,6 +8,14 @@
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
+SUPPORT_PEFT = False
+try:
+    import peft
+
+    SUPPORT_PEFT = True
+except ImportError:
+    pass
+
 import colossalai.interface.pretrained as pretrained_utils
 from colossalai.checkpoint_io import GeneralCheckpointIO
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -221,6 +229,38 @@ def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -
         assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync."
         return self.plugin.no_sync(model, optimizer)
 
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: "peft.LoraConfig" = None
+    ) -> nn.Module:
+        """
+        Wrap the passed in model with LoRA modules for training. If pretrained directory is provided, lora configs and weights are loaded from that directory.
+        Lora in ColossalAI is implemented using Huggingface peft library, so the arguments for Lora configuration are same as those of peft.
+
+        Args:
+            model (nn.Module): The model to be appended with LoRA modules.
+            pretrained_dir(str, optional): The path to the pretrained directory, can be a local directory
+                or model_id of a PEFT configuration hosted inside a model repo on the Hugging Face Hub.
+                When set to None, create new lora configs and weights for the model using the passed in lora_config. Defaults to None.
+            lora_config: (peft.LoraConfig, optional): Passed in LoraConfig for peft. Defaults to None.
+        """
+        if not SUPPORT_PEFT:
+            raise ImportError("Please install Huggingface Peft library to enable lora features in ColossalAI!")
+
+        assert self.plugin is not None, f"Lora can only be enabled when a plugin is provided."
+        assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support lora."
+        if pretrained_dir is None:
+            assert (
+                lora_config is not None
+            ), "Please provide configuration for Lora when pretrained directory path isn't passed in."
+            assert isinstance(
+                lora_config, peft.LoraConfig
+            ), "The passed in configuration should be an instance of peft.LoraConfig."
+        if lora_config is None:
+            assert (
+                pretrained_dir is not None
+            ), "Please provide pretrained directory path if not passing in lora configuration."
+        return self.plugin.enable_lora(model, pretrained_dir, lora_config)
+
     def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
         """Load model from checkpoint.
 
@@ -323,3 +363,20 @@ def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
             checkpoint (str): Path to the checkpoint. It must be a local file path.
         """
         self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)
+
+    def save_lora_as_pretrained(
+        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+    ) -> None:
+        """
+        Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.
+
+        Args:
+            model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
+            checkpoint (str): Path to the checkpoint directory. It must be a local path.
+            use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
+        """
+        if not SUPPORT_PEFT:
+            raise ImportError("Please install Huggingface Peft library to enable lora features in ColossalAI!")
+        assert self.plugin is not None, f"Lora can only be enabled when a plugin is provided."
+        assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support lora."
+        self.checkpoint_io.save_lora_as_pretrained(model, checkpoint, use_safetensors)
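
Taken together, the new Booster methods give the following workflow. Below is a minimal usage sketch, not part of the commit: the TinyNet model, the LoraConfig values, and the "lora_ckpt_dir" path are illustrative; it also assumes an already-initialized distributed environment (e.g. via colossalai.launch_from_torch) and the existing Booster.boost signature, neither of which is touched by this diff. Only TorchDDPPlugin reports support_lora() as True in this change.

import torch
import torch.nn as nn
from peft import LoraConfig

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin


# Hypothetical toy model (not from the commit); any module whose target layers are nn.Linear should work with peft.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_in = nn.Linear(32, 64)
        self.proj_out = nn.Linear(64, 2)

    def forward(self, x):
        return self.proj_out(torch.relu(self.proj_in(x)))


# Illustrative LoRA hyperparameters; these are ordinary peft.LoraConfig arguments.
lora_config = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.1, target_modules=["proj_in", "proj_out"])

booster = Booster(plugin=TorchDDPPlugin())  # the only plugin reporting support_lora() == True in this commit

model = TinyNet()
# 1. Inject LoRA modules before boosting; the plugin asserts the model is not yet DDP-wrapped.
model = booster.enable_lora(model, lora_config=lora_config)

# 2. Boost as usual; only the LoRA adapter parameters stay trainable.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... training loop ...

# 3. Persist just the adapters and adapter_config.json, peft-style (only the master rank writes).
booster.save_lora_as_pretrained(model, "lora_ckpt_dir", use_safetensors=True)

Note the ordering enforced by the diff: enable_lora must run on the raw nn.Module before boost, while save_lora_as_pretrained expects the boosted ModelWrapper.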

colossalai/booster/plugin/gemini_plugin.py

Lines changed: 9 additions & 1 deletion
@@ -3,7 +3,7 @@
 import os
 import random
 from pathlib import Path
-from typing import Callable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterator, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -444,6 +444,9 @@ def __del__(self):
     def support_no_sync(self) -> bool:
         return False
 
+    def support_lora(self) -> bool:
+        return False
+
     def control_precision(self) -> bool:
         return True
 
@@ -573,3 +576,8 @@ def get_checkpoint_io(self) -> CheckpointIO:
 
     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         raise NotImplementedError
+
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> nn.Module:
+        raise NotImplementedError

colossalai/booster/plugin/hybrid_parallel_plugin.py

Lines changed: 9 additions & 1 deletion
@@ -4,7 +4,7 @@
 from contextlib import contextmanager
 from functools import partial
 from types import MethodType
-from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union
 
 import numpy as np
 import torch
@@ -1156,6 +1156,9 @@ def control_precision(self) -> bool:
     def support_no_sync(self) -> bool:
         return True
 
+    def support_lora(self) -> bool:
+        return False
+
     def control_checkpoint_io(self) -> bool:
         return True
 
@@ -1356,3 +1359,8 @@ def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
             self.zero_stage != 2
         ), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed."
         return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
+
+    def enable_lora(
+        self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> Module:
+        raise NotImplementedError

colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 9 additions & 1 deletion
@@ -3,7 +3,7 @@
 from functools import partial
 from pathlib import Path
 from types import MethodType
-from typing import Callable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterator, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -296,6 +296,9 @@ def __init__(
     def support_no_sync(self) -> bool:
         return self.stage == 1
 
+    def support_lora(self) -> bool:
+        return False
+
     def control_precision(self) -> bool:
         return True
 
@@ -337,3 +340,8 @@ def get_checkpoint_io(self) -> CheckpointIO:
     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert isinstance(optimizer, LowLevelZeroOptimizer)
         return optimizer.no_sync()
+
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> nn.Module:
+        raise NotImplementedError

colossalai/booster/plugin/plugin_base.py

Lines changed: 11 additions & 1 deletion
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterator, List, Optional, Tuple
 
 import torch.nn as nn
 from torch.optim import Optimizer
@@ -33,6 +33,10 @@ def control_device(self) -> bool:
     def support_no_sync(self) -> bool:
         pass
 
+    @abstractmethod
+    def support_lora(self) -> bool:
+        pass
+
     @abstractmethod
     def configure(
         self,
@@ -63,6 +67,12 @@ def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         Context manager to disable gradient synchronization.
         """
 
+    @abstractmethod
+    def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
+        """
+        Add LoRA modules to the model passed in. Should only be called in booster.enable_lora().
+        """
+
     @abstractmethod
     def prepare_dataloader(
         self,

colossalai/booster/plugin/torch_ddp_plugin.py

Lines changed: 31 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import Callable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
 
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -116,6 +116,22 @@ def load_sharded_optimizer(
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
         super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
 
+    def save_lora_as_pretrained(
+        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+    ) -> None:
+        """
+        Save the lora adapters and adapter configuration file to checkpoint directory.
+        """
+        from peft import PeftModel
+
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        if self.coordinator.is_master():
+            peft_model = model.unwrap()
+            assert isinstance(
+                peft_model, PeftModel
+            ), "The model doesn't have lora adapters, please enable lora before saving."
+            peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
+
 
 class TorchDDPModel(ModelWrapper):
     def __init__(self, module: nn.Module, *args, **kwargs) -> None:
@@ -173,6 +189,9 @@ def __init__(
     def support_no_sync(self) -> bool:
         return True
 
+    def support_lora(self) -> bool:
+        return True
+
     def control_precision(self) -> bool:
         return False
 
@@ -216,3 +235,14 @@ def get_checkpoint_io(self) -> CheckpointIO:
     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin."
         return model.module.no_sync()
+
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> nn.Module:
+        from peft import PeftModel, get_peft_model
+
+        assert not isinstance(model, TorchDDPModel), "Lora should be enabled before boosting the model."
+        if pretrained_dir is None:
+            return get_peft_model(model, lora_config)
+        else:
+            return PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
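
For the resume path: when pretrained_dir is given, the plugin takes the PeftModel.from_pretrained(..., is_trainable=True) branch, so adapters written by save_lora_as_pretrained can be reattached to a freshly built base model. A short sketch, not part of the commit, reusing the imports and the hypothetical TinyNet model from the earlier sketch; the directory name is illustrative.

# Sketch (not from the commit): reload saved LoRA adapters for further training.
booster = Booster(plugin=TorchDDPPlugin())

base_model = TinyNet()  # must match the architecture the adapters were created for
# With pretrained_dir set, adapter weights and adapter_config.json are read from disk instead of a new lora_config.
model = booster.enable_lora(base_model, pretrained_dir="lora_ckpt_dir")

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)
# ... continue training, then call booster.save_lora_as_pretrained(...) again to refresh the adapters.

Per the checks in booster.enable_lora, at least one of lora_config or pretrained_dir must be supplied; when pretrained_dir is set, the TorchDDP plugin ignores lora_config and loads the saved adapter configuration instead.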

colossalai/booster/plugin/torch_fsdp_plugin.py

Lines changed: 9 additions & 1 deletion
@@ -2,7 +2,7 @@
 import os
 import warnings
 from pathlib import Path
-from typing import Callable, Iterable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -318,6 +318,9 @@ def __init__(
     def support_no_sync(self) -> bool:
         return False
 
+    def support_lora(self) -> bool:
+        return False
+
     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         raise NotImplementedError("Torch fsdp no_sync func not supported yet.")
 
@@ -361,3 +364,8 @@ def control_checkpoint_io(self) -> bool:
 
     def get_checkpoint_io(self) -> CheckpointIO:
         return TorchFSDPCheckpointIO()
+
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> nn.Module:
+        raise NotImplementedError

colossalai/checkpoint_io/checkpoint_io_base.py

Lines changed: 17 additions & 0 deletions
@@ -335,3 +335,20 @@ def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """
         state_dict = torch.load(checkpoint)
         lr_scheduler.load_state_dict(state_dict)
+
+    # ================================================================================
+    # Abstract method for lora saving implementation.
+    # ================================================================================
+
+    @abstractmethod
+    def save_lora_as_pretrained(
+        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+    ) -> None:
+        """
+        Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.
+
+        Args:
+            model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
+            checkpoint (str): Path to the checkpoint directory. It must be a local path.
+            use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
+        """

colossalai/checkpoint_io/general_checkpoint_io.py

Lines changed: 3 additions & 0 deletions
@@ -228,3 +228,6 @@ def load_sharded_model(
                     self.__class__.__name__, "\n\t".join(error_msgs)
                 )
             )
+
+    def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None:
+        raise NotImplementedError

requirements/requirements-test.txt

Lines changed: 2 additions & 1 deletion
@@ -5,7 +5,7 @@ git+https://github.com/hpcaitech/pytest-testmon
 torchvision
 timm
 titans
-torchaudio
+torchaudio>=0.13.1
 torchx-nightly==2022.6.29 # torchrec 0.2.0 requires torchx-nightly. This package is updated every day. We fix the version to a specific date to avoid breaking changes.
 torchrec==0.2.0
 contexttimer
@@ -18,4 +18,5 @@ flash_attn
 datasets
 pydantic
 ray
+peft
 #auto-gptq now not support torch1.12
