@@ -8,6 +8,14 @@
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader

+SUPPORT_PEFT = False
+try:
+    import peft
+
+    SUPPORT_PEFT = True
+except ImportError:
+    pass
+
 import colossalai.interface.pretrained as pretrained_utils
 from colossalai.checkpoint_io import GeneralCheckpointIO
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -221,6 +229,38 @@ def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -
         assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync."
         return self.plugin.no_sync(model, optimizer)

+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: "peft.LoraConfig" = None
+    ) -> nn.Module:
+        """
+        Wrap the given model with LoRA modules for training. If a pretrained directory is provided, the LoRA config and weights are loaded from that directory.
+        LoRA in ColossalAI is implemented with the Hugging Face PEFT library, so the LoRA configuration arguments are the same as those of PEFT.
+
+        Args:
+            model (nn.Module): The model to be wrapped with LoRA modules.
+            pretrained_dir (str, optional): Path to the pretrained directory; it can be a local directory
+                or the model_id of a PEFT configuration hosted in a model repo on the Hugging Face Hub.
+                When set to None, a new LoRA config and weights are created for the model from the passed-in lora_config. Defaults to None.
+            lora_config (peft.LoraConfig, optional): The LoraConfig passed to PEFT. Defaults to None.
+        """
+        if not SUPPORT_PEFT:
+            raise ImportError("Please install the Hugging Face PEFT library to enable LoRA features in ColossalAI!")
+
+        assert self.plugin is not None, "LoRA can only be enabled when a plugin is provided."
+        assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support LoRA."
+        if pretrained_dir is None:
+            assert (
+                lora_config is not None
+            ), "Please provide a LoRA configuration when no pretrained directory path is passed in."
+            assert isinstance(
+                lora_config, peft.LoraConfig
+            ), "The passed-in configuration should be an instance of peft.LoraConfig."
+        if lora_config is None:
+            assert (
+                pretrained_dir is not None
+            ), "Please provide a pretrained directory path if no LoRA configuration is passed in."
+        return self.plugin.enable_lora(model, pretrained_dir, lora_config)
+
     def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
         """Load model from checkpoint.

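For reference, a minimal usage sketch of the new `enable_lora` API. It is not part of this diff: the helper name `wrap_with_lora`, the LoRA hyperparameters, and the `target_modules` names are illustrative assumptions, and any plugin whose `support_lora()` returns True is expected to work the same way. Passing a previously saved adapter directory as `pretrained_dir` (and omitting `lora_config`) instead reloads an existing adapter, per the docstring above.

```python
# Minimal sketch (illustrative, not part of this diff): create fresh LoRA
# adapters for a model before training, assuming a LoRA-capable plugin.
import torch.nn as nn
from peft import LoraConfig

from colossalai.booster import Booster


def wrap_with_lora(booster: Booster, model: nn.Module) -> nn.Module:
    # pretrained_dir is None, so new adapter weights are initialized from lora_config.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=["query", "value"],  # module names depend on the model architecture
    )
    return booster.enable_lora(model, pretrained_dir=None, lora_config=lora_config)
```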
@@ -323,3 +363,20 @@ def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
             checkpoint (str): Path to the checkpoint. It must be a local file path.
         """
         self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)
+
+    def save_lora_as_pretrained(
+        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+    ) -> None:
+        """
+        Save the LoRA adapters and the adapter configuration file to a pretrained checkpoint directory.
+
+        Args:
+            model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
+            checkpoint (str): Path to the checkpoint directory. It must be a local path.
+            use_safetensors (bool, optional): Whether to use safetensors when saving. Defaults to False.
+        """
+        if not SUPPORT_PEFT:
+            raise ImportError("Please install the Hugging Face PEFT library to enable LoRA features in ColossalAI!")
+        assert self.plugin is not None, "LoRA can only be enabled when a plugin is provided."
+        assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support LoRA."
+        self.checkpoint_io.save_lora_as_pretrained(model, checkpoint, use_safetensors)
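A matching sketch for persisting the adapters after training. The helper name and the checkpoint directory are illustrative assumptions; per the docstring above, the saved directory can later be passed back to `enable_lora` as `pretrained_dir`.

```python
# Minimal sketch (illustrative): save only the LoRA adapters and their config,
# so they can be reloaded later through enable_lora(pretrained_dir=...).
from typing import Union

import torch.nn as nn

from colossalai.booster import Booster
from colossalai.interface import ModelWrapper


def save_adapters(booster: Booster, model: Union[nn.Module, ModelWrapper]) -> None:
    # Writes the adapter weights and the adapter configuration file into the directory.
    booster.save_lora_as_pretrained(model, "checkpoints/lora_adapter", use_safetensors=True)
```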