
Commit b10339d

fix lora ckpt save format (ColoTensor to Tensor)
1 parent 5ddad48 commit b10339d

File tree

3 files changed: +11 -3 lines changed

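The change is the same in all three checkpoint-IO paths: instead of letting PEFT's save_pretrained serialize the model's raw state dict (whose LoRA weights may still be ColoTensor wrappers), the state dict is first passed through tree_map so every tensor leaf is replaced by its plain .data tensor. A minimal sketch of that conversion, using ordinary tensors as stand-ins for ColoTensor leaves (the keys and shapes below are made up for illustration):

    # Sketch of the state-dict conversion the patch applies before saving.
    # tree_map from torch.utils._pytree visits every leaf of a nested container
    # and applies the given function; non-tensor leaves are passed through as-is.
    import torch
    from torch.utils._pytree import tree_map

    state_dict = {
        "base_model.model.q_proj.lora_A.weight": torch.randn(8, 64),  # stand-in for a ColoTensor leaf
        "base_model.model.q_proj.lora_B.weight": torch.randn(64, 8),
    }

    plain = tree_map(lambda x: x.data if torch.is_tensor(x) else x, state_dict)

    # Every value is now a plain torch.Tensor, which torch.save / safetensors
    # can serialize without any tensor-subclass wrapper.
    assert all(type(v) is torch.Tensor for v in plain.values())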

colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 2 additions & 1 deletion
@@ -290,7 +290,8 @@ def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
+                                          state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()))


 class LowLevelZeroPlugin(DPPluginBase):

colossalai/booster/plugin/torch_ddp_plugin.py

Lines changed: 5 additions & 1 deletion
@@ -1,10 +1,12 @@
 from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union

+import torch
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
+from torch.utils._pytree import tree_map

 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
@@ -134,7 +136,9 @@ def save_lora_as_pretrained(
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
+                                          state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x,
+                                                              peft_model.state_dict()))


 class TorchDDPModel(ModelWrapper):

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

Lines changed: 4 additions & 1 deletion
@@ -11,6 +11,7 @@
 import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+from torch.utils._pytree import tree_map

 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -956,4 +957,6 @@ def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
+        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
+                                          state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x,
+                                                              peft_model.state_dict()))
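To spot-check a checkpoint written after this fix, the saved adapter file can be loaded back with safetensors, which only round-trips plain tensors. This sketch assumes the adapter was saved with use_safetensors=True and that PEFT wrote its default file name (adapter_model.safetensors); the directory path is hypothetical.

    import os

    import torch
    from safetensors.torch import load_file

    ckpt_dir = "path/to/lora_checkpoint"  # hypothetical checkpoint directory
    loaded = load_file(os.path.join(ckpt_dir, "adapter_model.safetensors"))

    # load_file returns plain torch.Tensor objects keyed by parameter name,
    # so a successful load means no ColoTensor wrapper survived serialization.
    for name, tensor in loaded.items():
        assert type(tensor) is torch.Tensor, name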
