Commit 6d6cafa

pre-commit fix
1 parent: b10339d


3 files changed: +16 −9 lines


colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 5 additions & 2 deletions
@@ -290,8 +290,11 @@ def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
-            state_dict = tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()))
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )


 class LowLevelZeroPlugin(DPPluginBase):
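
The change in all three files is the same: the long `peft_model.save_pretrained(...)` call is split across multiple lines (and `state_dict =` tightened to `state_dict=`) so the pre-commit formatter passes; behaviour is unchanged. The `state_dict=tree_map(...)` argument replaces every tensor in the PEFT state dict with its underlying `.data` before serialization. Below is a minimal sketch of that mapping in isolation; the `nn.Linear` is only a hypothetical stand-in for the PEFT model in the patch.

import torch
from torch.utils._pytree import tree_map

# Hypothetical stand-in for the PeftModel in the patch; any nn.Module works here.
model = torch.nn.Linear(4, 4)

# Map over the state dict, replacing each tensor leaf with its .data view and
# leaving non-tensor leaves untouched -- the same lambda the patch passes as
# state_dict= to peft_model.save_pretrained().
plain_state_dict = tree_map(
    lambda x: x.data if torch.is_tensor(x) else x,
    model.state_dict(),
)

for name, tensor in plain_state_dict.items():
    # .data views carry no autograd history (requires_grad is False).
    print(name, tuple(tensor.shape), tensor.requires_grad)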

colossalai/booster/plugin/torch_ddp_plugin.py

Lines changed: 6 additions & 4 deletions
@@ -5,8 +5,8 @@
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
-from torch.utils.data import DataLoader
 from torch.utils._pytree import tree_map
+from torch.utils.data import DataLoader

 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
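
The import shuffle above is purely a sort: `torch.utils._pytree` orders before `torch.utils.data` because `_` precedes `d` in ASCII. This is most likely the work of the repository's isort pre-commit hook rather than a functional change.
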
@@ -136,9 +136,11 @@ def save_lora_as_pretrained(
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
-            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x,
-            peft_model.state_dict()))
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )


 class TorchDDPModel(ModelWrapper):

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

Lines changed: 5 additions & 3 deletions
@@ -957,6 +957,8 @@ def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
         assert isinstance(
             peft_model, PeftModel
         ), "The model doesn't have lora adapters, please enable lora before saving."
-        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors,
-            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x,
-            peft_model.state_dict()))
+        return peft_model.save_pretrained(
+            checkpoint,
+            safe_serialization=use_safetensors,
+            state_dict=tree_map(lambda x: x.data if torch.is_tensor(x) else x, peft_model.state_dict()),
+        )
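
For context, these `save_lora_as_pretrained` methods are normally reached through the Booster API rather than called directly. Below is a rough usage sketch under the assumption that `Booster.enable_lora` and `Booster.save_lora_as_pretrained` route to the patched checkpoint-IO methods; `build_model` is a hypothetical helper, and launch/config details vary between ColossalAI versions.

import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from peft import LoraConfig

colossalai.launch_from_torch()  # assumes torchrun-style environment variables are set

plugin = LowLevelZeroPlugin()
booster = Booster(plugin=plugin)

model = build_model()  # hypothetical helper returning a transformers model

# Attach LoRA adapters first, then build the optimizer over the (now partly frozen) parameters.
lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)
model = booster.enable_lora(model, lora_config=lora_config)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... training loop ...

# Dispatches to the plugin's checkpoint IO, i.e. the save_lora_as_pretrained
# methods reformatted in this commit.
booster.save_lora_as_pretrained(model, "ckpt/lora_adapter", use_safetensors=True)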
