Merged
1 change: 1 addition & 0 deletions .github/workflows/build_on_pr.yml
@@ -166,6 +166,7 @@ jobs:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors
+HF_ENDPOINT: https://hf-mirror.com

- name: Collate artifact
env:
1 change: 1 addition & 0 deletions .github/workflows/build_on_schedule.yml
@@ -70,6 +70,7 @@ jobs:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors
+HF_ENDPOINT: https://hf-mirror.com

- name: Notify Lark
id: message-preparation
1 change: 1 addition & 0 deletions .github/workflows/compatiblity_test_on_dispatch.yml
@@ -79,3 +79,4 @@ jobs:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors
+HF_ENDPOINT: https://hf-mirror.com
1 change: 1 addition & 0 deletions .github/workflows/compatiblity_test_on_pr.yml
@@ -73,3 +73,4 @@ jobs:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors
+HF_ENDPOINT: https://hf-mirror.com
1 change: 1 addition & 0 deletions .github/workflows/compatiblity_test_on_schedule.yml
@@ -67,6 +67,7 @@ jobs:
LD_LIBRARY_PATH: /github/home/.tensornvme/lib
LLAMA_PATH: /data/scratch/llama-tiny
MOE_TENSOR_PATH: /data/scratch/moe_tensors
+HF_ENDPOINT: https://hf-mirror.com

- name: Notify Lark
id: message-preparation
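The five workflow edits above are one and the same change: each CI job exports HF_ENDPOINT so that Hugging Face Hub traffic is routed through the hf-mirror.com mirror. A minimal sketch of the effect, assuming huggingface_hub is installed (the repo id is illustrative):

import os

# huggingface_hub reads HF_ENDPOINT when it is imported, so set it first.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

from huggingface_hub import hf_hub_download

# This download now goes through the mirror instead of huggingface.co.
path = hf_hub_download(repo_id="hf-internal-testing/tiny-random-gpt2", filename="config.json")
print(path)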
15 changes: 9 additions & 6 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -10,6 +10,7 @@
import numpy as np
import torch
import torch.distributed as dist
+from peft import PeftModel
from torch import Tensor, inf
from torch.distributed import ProcessGroup, get_world_size
from torch.nn import Module, SyncBatchNorm
@@ -219,11 +220,13 @@ def forward(self, *args, **kwargs):
with self._hook_context():
return super().forward(*args, **kwargs)

-def unwrap(self):
-    module = super().unwrap()
-    if isinstance(module, DDP):
-        module = module.module
-    return module
+def unwrap(self, unwrap_peft: bool = True):
+    model = self.module
+    if isinstance(model, DDP):
+        model = model.module
+    if unwrap_peft and isinstance(model, PeftModel):
+        model = model.get_base_model()
+    return model

def _force_wait_all_gather(self):
for p in self.module.parameters():
@@ -1509,7 +1512,7 @@ def enable_lora(
from peft import PeftModel, get_peft_model

assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
-assert self.pp_size == 1 and self.tp_size == 1
+assert self.tp_size == 1
self.lora_enabled = True
self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])

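Two things happen in this file: unwrap() no longer delegates to super().unwrap() and gains an unwrap_peft flag that can strip a PEFT wrapper, and enable_lora drops the pp_size == 1 restriction, which is what allows LoRA together with pipeline parallelism. A small sketch of the two unwrap behaviours, assuming a LoRA-wrapped GPT-2 (model choice illustrative):

from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("gpt2")
lora_model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))

# unwrap_peft=True (the default) corresponds to stripping down to the base model:
assert lora_model.get_base_model() is base
# unwrap_peft=False keeps the PeftModel, whose save_pretrained() can later write
# the adapter weights plus adapter_config.json; the checkpoint IO below relies on this.
assert isinstance(lora_model, PeftModel)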
19 changes: 3 additions & 16 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -359,23 +359,10 @@ def save_sharded_model(
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async=use_async
)

-def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
-    if os.path.isfile(checkpoint):
-        self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
-        return
-    from peft import PeftModel
-
-    assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict: Optional[dict] = None):
+    assert isinstance(model, LowLevelZeroModel), "Please boost the model before saving!"
     model._force_wait_all_gather()
-    peft_model = model.unwrap()
-    assert isinstance(
-        peft_model, PeftModel
-    ), "The model doesn't have lora adapters, please enable lora before saving."
-    return peft_model.save_pretrained(
-        checkpoint,
-        safe_serialization=use_safetensors,
-        state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
-    )
+    super().save_lora_as_pretrained(model, checkpoint, use_safetensors, state_dict=state_dict)


class LowLevelZeroPlugin(DPPluginBase):
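The LowLevelZero override is now a thin shim: it tightens the wrapper check to LowLevelZeroModel, waits for in-flight all-gathers, and defers to the base-class implementation (in ColossalAI the low-level zero checkpoint IO inherits from the TorchDDP one shown in the next file, an assumption from memory rather than something visible in this diff), so the PeftModel assertion and the master-only write live in one place. The user-facing call is unchanged; a hedged usage sketch, assuming a booster created with this plugin and a LoRA-enabled model (path illustrative):

booster.save_lora_as_pretrained(model, "checkpoints/lora_adapter", use_safetensors=True)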
26 changes: 18 additions & 8 deletions colossalai/booster/plugin/torch_ddp_plugin.py
@@ -2,6 +2,7 @@

import torch
import torch.nn as nn
+from peft import PeftModel
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
@@ -166,23 +167,29 @@ def load_sharded_optimizer(
)

 def save_lora_as_pretrained(
-    self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+    self,
+    model: Union[nn.Module, ModelWrapper],
+    checkpoint: str,
+    use_safetensors: bool = False,
+    state_dict: Optional[dict] = None,
 ) -> None:
     """
     Save the lora adapters and adapter configuration file to checkpoint directory.
     """
     from peft import PeftModel

     assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+    peft_model = model.unwrap(unwrap_peft=False)
+    assert isinstance(
+        peft_model, PeftModel
+    ), "The model doesn't have lora adapters, please enable lora before saving."
+    if state_dict is None:
+        state_dict = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, peft_model.state_dict())
     if self.coordinator.is_master():
-        peft_model = model.unwrap()
-        assert isinstance(
-            peft_model, PeftModel
-        ), "The model doesn't have lora adapters, please enable lora before saving."
         return peft_model.save_pretrained(
             checkpoint,
             safe_serialization=use_safetensors,
-            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+            state_dict=state_dict,
         )


@@ -191,8 +198,11 @@ def __init__(self, module: nn.Module, *args, **kwargs) -> None:
super().__init__(module)
self.module = DDP(module, *args, **kwargs)

-def unwrap(self):
-    return self.module.module
+def unwrap(self, unwrap_peft: bool = True) -> nn.Module:
+    model = self.module.module
+    if unwrap_peft and isinstance(model, PeftModel):
+        model = model.get_base_model()
+    return model


class TorchDDPPlugin(DPPluginBase):
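Note the behavioural detail above: the state dict is detached and moved to CPU on every rank, then only the master rank writes files. The tree_map call in isolation (a standalone sketch; key names illustrative):

import torch
from torch.utils._pytree import tree_map

state = {"lora_A.weight": torch.randn(8, 4), "meta": {"rank": 8}}
# Detach each tensor (x.data) and copy it to CPU; non-tensor leaves pass through.
cpu_state = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, state)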
3 changes: 0 additions & 3 deletions colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -437,9 +437,6 @@ def __init__(self, module: nn.Module, *args, **kwargs) -> None:
super().__init__(module)
self.module = FSDP(module, *args, **kwargs)

-def unwrap(self):
-    return self.module
-

class FSDPOptimizerWrapper(OptimizerWrapper):
def __init__(self, optimizer: Optimizer, model: nn.Module):
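With its override deleted, FSDPModel now falls back to ModelWrapper.unwrap(), which this diff does not show; presumably the base method gained the same unwrap_peft handling as the subclasses above. A guess at its shape, labeled as an assumption rather than code from this PR:

# Assumed shape of the inherited default (not part of this diff):
def unwrap(self, unwrap_peft: bool = True):
    model = self.module
    if isinstance(model, ModelWrapper):
        model = model.unwrap(unwrap_peft)
    if unwrap_peft and isinstance(model, PeftModel):
        model = model.get_base_model()
    return model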
7 changes: 6 additions & 1 deletion colossalai/checkpoint_io/checkpoint_io_base.py
@@ -437,7 +437,11 @@ def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):

@abstractmethod
def save_lora_as_pretrained(
-    self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+    self,
+    model: Union[nn.Module, ModelWrapper],
+    checkpoint: str,
+    use_safetensors: bool = False,
+    state_dict: Optional[dict] = None,
) -> None:
"""
Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.
@@ -446,4 +450,5 @@ def save_lora_as_pretrained(
model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
checkpoint (str): Path to the checkpoint directory. It must be a local path.
use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
+state_dict (Optional[dict], optional): The state dict to save. Defaults to None.
"""
4 changes: 3 additions & 1 deletion colossalai/checkpoint_io/general_checkpoint_io.py
@@ -308,5 +308,7 @@ def load_sharded_model(
)
)

-def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None:
+def save_lora_as_pretrained(
+    self, model: nn.Module, checkpoint: str, use_safetensors: bool = False, state_dict: Optional[dict] = None
+) -> None:
raise NotImplementedError
25 changes: 18 additions & 7 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -31,6 +31,8 @@
async_save_state_dict_shards,
create_pinned_state_dict,
gather_distributed_param,
+gather_state_dict_fast,
+get_lora_state_dict,
get_model_base_filenames,
get_optimizer_base_filenames,
is_safetensors_available,
@@ -1139,20 +1141,29 @@ def shard_from_complete_optimizer_state(

return state_

-def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
+def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict: Optional[dict] = None):
     if os.path.isfile(checkpoint):
         logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
         return
     from peft import PeftModel

     assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
     model._force_wait_all_gather()
-    peft_model = model.unwrap()
+    peft_model = model.unwrap(unwrap_peft=False)
     assert isinstance(
         peft_model, PeftModel
     ), "The model doesn't have lora adapters, please enable lora before saving."
-    return peft_model.save_pretrained(
-        checkpoint,
-        safe_serialization=use_safetensors,
-        state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
-    )
+    if state_dict is None:
+        state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict())
+    if self.pp_size > 1:
+        lora_state_dict = get_lora_state_dict(peft_model, state_dict)
+        gathered_lora_state_dict = gather_state_dict_fast(lora_state_dict, self.pp_group, device="cpu")
+        if self.pp_rank == 0:
+            state_dict.update(gathered_lora_state_dict)
+    state_dict = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)
+    if self.coordinator.is_master():
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=state_dict,
+        )
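The new pp_size > 1 branch is the core of the PR: under pipeline parallelism each stage holds only the LoRA tensors of its own layers, so the adapter entries are gathered across the pipeline group and merged on pp_rank 0 before the master rank saves. gather_state_dict_fast is ColossalAI's optimized collective helper; a naive equivalent using plain torch.distributed, shown only to illustrate the semantics:

import torch.distributed as dist

def gather_state_dict_naive(state_dict: dict, group) -> dict:
    # Every rank contributes its shard and receives all shards.
    shards = [None] * dist.get_world_size(group)
    dist.all_gather_object(shards, state_dict, group=group)
    merged = {}
    for shard in shards:
        merged.update(shard)  # keys are disjoint across pipeline stages
    return merged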
26 changes: 26 additions & 0 deletions colossalai/checkpoint_io/moe_checkpoint.py
@@ -10,13 +10,16 @@
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import get_global_rank
+from torch.utils._pytree import tree_map

from colossalai.checkpoint_io import CheckpointIndexFile
from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
from colossalai.checkpoint_io.index_file import CheckpointIndexFile
from colossalai.checkpoint_io.utils import (
StateDictSharder,
gather_distributed_param,
+gather_state_dict_fast,
+get_lora_state_dict,
get_model_base_filenames,
get_optimizer_base_filenames,
load_shard_state_dict,
@@ -889,3 +892,26 @@ def _get_param_id_from_optimizer_param(
optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim)
dist.barrier()

+def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict=None):
+    if os.path.isfile(checkpoint):
+        logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+        return
+    from peft import PeftModel
+
+    assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+    model._force_wait_all_gather()
+    peft_model = model.unwrap(unwrap_peft=False)
+    assert isinstance(
+        peft_model, PeftModel
+    ), "The model doesn't have lora adapters, please enable lora before saving."
+    if state_dict is None:
+        state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict())
+    if self.ep_size > 1:
+        lora_state_dict = get_lora_state_dict(peft_model, state_dict)
+        moe_params = set(n for n, p in peft_model.named_parameters() if is_moe_tensor(p))
+        expert_state_dict = {n: p for n, p in lora_state_dict.items() if n in moe_params}
+        gathered_expert_state_dict = gather_state_dict_fast(expert_state_dict, self.ep_group)
+        if self.ep_rank == 0:
+            state_dict.update(gathered_expert_state_dict)
+    return super().save_lora_as_pretrained(model, checkpoint, use_safetensors, state_dict)
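The MoE override repeats the pattern one level down: it gathers only the LoRA entries that live on expert tensors across the expert-parallel group, merges them on ep_rank 0, and then hands the combined state dict to the HybridParallel implementation above, which still handles the pipeline gather and the final write. The filtering step in isolation (a hedged sketch; the helper name is invented here):

def split_expert_lora_entries(lora_state_dict: dict, moe_param_names: set):
    """Separate LoRA entries on MoE expert tensors from dense ones."""
    expert = {n: p for n, p in lora_state_dict.items() if n in moe_param_names}
    dense = {n: p for n, p in lora_state_dict.items() if n not in moe_param_names}
    return expert, dense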