
Commit 014837e

[shardformer] support pipeline for deepseek v3 and optimize lora save (#6188)
* [shardformer] support pipeline for deepseek v3
* [checkpointio] fix lora save
* [devops] update ci env
* [booster] optimize lora
* fix test
* fix test

1 parent ec73f1b commit 014837e

21 files changed (+478, −91 lines)

.github/workflows/build_on_pr.yml

Lines changed: 1 addition & 0 deletions
@@ -166,6 +166,7 @@ jobs:
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
           LLAMA_PATH: /data/scratch/llama-tiny
           MOE_TENSOR_PATH: /data/scratch/moe_tensors
+          HF_ENDPOINT: https://hf-mirror.com

       - name: Collate artifact
         env:

.github/workflows/build_on_schedule.yml

Lines changed: 1 addition & 0 deletions
@@ -70,6 +70,7 @@ jobs:
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
           LLAMA_PATH: /data/scratch/llama-tiny
           MOE_TENSOR_PATH: /data/scratch/moe_tensors
+          HF_ENDPOINT: https://hf-mirror.com

       - name: Notify Lark
         id: message-preparation

.github/workflows/compatiblity_test_on_dispatch.yml

Lines changed: 1 addition & 0 deletions
@@ -79,3 +79,4 @@ jobs:
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib
           LLAMA_PATH: /data/scratch/llama-tiny
           MOE_TENSOR_PATH: /data/scratch/moe_tensors
+          HF_ENDPOINT: https://hf-mirror.com

.github/workflows/compatiblity_test_on_pr.yml

Lines changed: 1 addition & 0 deletions
@@ -73,3 +73,4 @@ jobs:
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib
           LLAMA_PATH: /data/scratch/llama-tiny
           MOE_TENSOR_PATH: /data/scratch/moe_tensors
+          HF_ENDPOINT: https://hf-mirror.com

.github/workflows/compatiblity_test_on_schedule.yml

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ jobs:
           LD_LIBRARY_PATH: /github/home/.tensornvme/lib
           LLAMA_PATH: /data/scratch/llama-tiny
           MOE_TENSOR_PATH: /data/scratch/moe_tensors
+          HF_ENDPOINT: https://hf-mirror.com

       - name: Notify Lark
         id: message-preparation
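
All five workflows export HF_ENDPOINT so that Hugging Face Hub downloads in CI resolve against the hf-mirror.com mirror. A minimal local sketch of the same mechanism (assuming huggingface_hub is installed; the variable must be set before the library is imported, since the endpoint is read at import time):

import os

# Assumption for this sketch: HF_ENDPOINT redirects all huggingface_hub
# traffic to the mirror when set before the library is imported.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

from huggingface_hub import constants

print(constants.ENDPOINT)  # expected to print https://hf-mirror.com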

colossalai/booster/plugin/hybrid_parallel_plugin.py

Lines changed: 9 additions & 6 deletions
@@ -10,6 +10,7 @@
 import numpy as np
 import torch
 import torch.distributed as dist
+from peft import PeftModel
 from torch import Tensor, inf
 from torch.distributed import ProcessGroup, get_world_size
 from torch.nn import Module, SyncBatchNorm
@@ -219,11 +220,13 @@ def forward(self, *args, **kwargs):
         with self._hook_context():
             return super().forward(*args, **kwargs)

-    def unwrap(self):
-        module = super().unwrap()
-        if isinstance(module, DDP):
-            module = module.module
-        return module
+    def unwrap(self, unwrap_peft: bool = True):
+        model = self.module
+        if isinstance(model, DDP):
+            model = model.module
+        if unwrap_peft and isinstance(model, PeftModel):
+            model = model.get_base_model()
+        return model

     def _force_wait_all_gather(self):
         for p in self.module.parameters():
@@ -1509,7 +1512,7 @@ def enable_lora(
         from peft import PeftModel, get_peft_model

         assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
-        assert self.pp_size == 1 and self.tp_size == 1
+        assert self.tp_size == 1
         self.lora_enabled = True
         self.logger.warning("You have enabled LoRa training. Please check the hyperparameters such as lr", ranks=[0])
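
The unwrap change above lets callers choose between the PeftModel (needed for saving LoRA adapters) and the underlying base module. A standalone sketch of that distinction, using a toy model rather than an actual ColossalAI HybridParallelModule (peft and torch assumed installed; the tiny model and target_modules choice are illustrative only):

import torch.nn as nn
from peft import LoraConfig, PeftModel, get_peft_model

# Toy stand-in for a boosted model; target_modules=["0"] targets the Linear layer.
base = nn.Sequential(nn.Linear(16, 16))
peft_model = get_peft_model(base, LoraConfig(target_modules=["0"]))

# unwrap() / unwrap(unwrap_peft=True) would hand back the base module via get_base_model() ...
print(type(peft_model.get_base_model()))  # the underlying nn.Sequential
# ... while unwrap(unwrap_peft=False) keeps the PeftModel, whose save_pretrained()
# writes only the adapter weights and adapter config.
print(isinstance(peft_model, PeftModel))  # True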

colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 3 additions & 16 deletions
@@ -359,23 +359,10 @@ def save_sharded_model(
             model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async=use_async
         )

-    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
-        if os.path.isfile(checkpoint):
-            self.logger.error(f"Provided path ({checkpoint}) should be a directory, not a file", ranks=[0])
-            return
-        from peft import PeftModel
-
-        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors, state_dict: Optional[dict] = None):
+        assert isinstance(model, LowLevelZeroModel), "Please boost the model before saving!"
         model._force_wait_all_gather()
-        peft_model = model.unwrap()
-        assert isinstance(
-            peft_model, PeftModel
-        ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(
-            checkpoint,
-            safe_serialization=use_safetensors,
-            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
-        )
+        super().save_lora_as_pretrained(model, checkpoint, use_safetensors, state_dict=state_dict)


 class LowLevelZeroPlugin(DPPluginBase):
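
With this change the LowLevelZero checkpoint IO delegates LoRA saving to the shared implementation instead of duplicating it. A hedged end-to-end sketch of the Booster LoRA flow this serves (build_model and build_optimizer are hypothetical helpers; the launch and boost calls follow the usual ColossalAI entry points, which this diff does not show):

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from peft import LoraConfig

colossalai.launch_from_torch()

booster = Booster(plugin=LowLevelZeroPlugin())
model = build_model()  # hypothetical helper returning a causal LM
model = booster.enable_lora(model, lora_config=LoraConfig(task_type="CAUSAL_LM"))
model, optimizer, *_ = booster.boost(model, build_optimizer(model))  # hypothetical helper

# ... training loop elided ...

# Writes only the adapter weights and adapter config to the directory.
booster.save_lora_as_pretrained(model, "./lora_ckpt", use_safetensors=True)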

colossalai/booster/plugin/torch_ddp_plugin.py

Lines changed: 18 additions & 8 deletions
@@ -2,6 +2,7 @@

 import torch
 import torch.nn as nn
+from peft import PeftModel
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
@@ -166,23 +167,29 @@ def load_sharded_optimizer(
         )

     def save_lora_as_pretrained(
-        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+        self,
+        model: Union[nn.Module, ModelWrapper],
+        checkpoint: str,
+        use_safetensors: bool = False,
+        state_dict: Optional[dict] = None,
     ) -> None:
         """
         Save the lora adapters and adapter configuration file to checkpoint directory.
         """
         from peft import PeftModel

         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        peft_model = model.unwrap(unwrap_peft=False)
+        assert isinstance(
+            peft_model, PeftModel
+        ), "The model doesn't have lora adapters, please enable lora before saving."
+        if state_dict is None:
+            state_dict = tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, peft_model.state_dict())
         if self.coordinator.is_master():
-            peft_model = model.unwrap()
-            assert isinstance(
-                peft_model, PeftModel
-            ), "The model doesn't have lora adapters, please enable lora before saving."
             return peft_model.save_pretrained(
                 checkpoint,
                 safe_serialization=use_safetensors,
-                state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+                state_dict=state_dict,
             )

@@ -191,8 +198,11 @@ def __init__(self, module: nn.Module, *args, **kwargs) -> None:
         super().__init__(module)
         self.module = DDP(module, *args, **kwargs)

-    def unwrap(self):
-        return self.module.module
+    def unwrap(self, unwrap_peft: bool = True) -> nn.Module:
+        model = self.module.module
+        if unwrap_peft and isinstance(model, PeftModel):
+            model = model.get_base_model()
+        return model


 class TorchDDPPlugin(DPPluginBase):
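
The rewritten method now builds the adapter state dict on every rank, moves each tensor to CPU, and only then lets the master rank write files; a caller that already has a prepared state_dict can pass it in and skip that step. A small sketch of the tensor-offload pattern, assuming tree_map here is torch.utils._pytree.tree_map as used in the plugin code:

import torch
from torch.utils._pytree import tree_map


def offload_state_dict_to_cpu(state_dict: dict) -> dict:
    # Detach each tensor to CPU; non-tensor leaves pass through unchanged.
    return tree_map(lambda x: x.data.cpu() if torch.is_tensor(x) else x, state_dict)


# Example with a dummy adapter-style state dict:
sd = {"lora_A.weight": torch.randn(4, 8), "step": 3}
print({k: getattr(v, "device", None) for k, v in offload_state_dict_to_cpu(sd).items()})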

colossalai/booster/plugin/torch_fsdp_plugin.py

Lines changed: 0 additions & 3 deletions
@@ -437,9 +437,6 @@ def __init__(self, module: nn.Module, *args, **kwargs) -> None:
         super().__init__(module)
         self.module = FSDP(module, *args, **kwargs)

-    def unwrap(self):
-        return self.module
-

 class FSDPOptimizerWrapper(OptimizerWrapper):
     def __init__(self, optimizer: Optimizer, model: nn.Module):

colossalai/checkpoint_io/checkpoint_io_base.py

Lines changed: 6 additions & 1 deletion
@@ -437,7 +437,11 @@ def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):

     @abstractmethod
     def save_lora_as_pretrained(
-        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+        self,
+        model: Union[nn.Module, ModelWrapper],
+        checkpoint: str,
+        use_safetensors: bool = False,
+        state_dict: Optional[dict] = None,
     ) -> None:
         """
         Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.
@@ -446,4 +450,5 @@ def save_lora_as_pretrained(
             model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
             checkpoint (str): Path to the checkpoint directory. It must be a local path.
             use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
+            state_dict (Optional[dict], optional): The state dict to save. Defaults to None.
         """
