
Commit 9c79b38

tastelikefeet authored and committed
fix (#5147)
Co-authored-by: tastelikefeet <[email protected]>
1 parent 294713e commit 9c79b38

File tree

4 files changed: +22 -13 lines changed


swift/cli/sft.py

Lines changed: 3 additions & 0 deletions
@@ -1,6 +1,9 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
 
+if int(os.environ.get('UNSLOTH_PATCH_TRL', '0')) != 0:
+    import unsloth
+
 from swift.llm import sft_main
 
 if __name__ == '__main__':
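
The new guard makes the unsloth import opt-in and places it before `from swift.llm import sft_main`: unsloth monkey-patches transformers/TRL as an import side effect, so it has to load before the rest of the training stack. A minimal sketch of the same gate, runnable on its own:

import os

# Opt-in: unsloth patches transformers/trl at import time, so the import
# must run before any training modules are pulled in.
if int(os.environ.get('UNSLOTH_PATCH_TRL', '0')) != 0:
    import unsloth  # noqa: F401 -- imported for its patching side effects

Enable it per run, e.g. `UNSLOTH_PATCH_TRL=1 swift sft ...`.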

swift/llm/model/register.py

Lines changed: 14 additions & 10 deletions
@@ -139,31 +139,35 @@ def load_by_unsloth(args):
 
     @contextmanager
     def _patch_distributed_function():
-        from unsloth_zoo import utils
+        from unsloth_zoo import utils, compiler
 
         def distributed_function(n=1, function=None, *args, **kwargs):
             return function(*args, **kwargs)
 
         _origin_distributed_function = utils.distributed_function
         utils.distributed_function = distributed_function
+        compiler.distributed_function = distributed_function
         yield
         utils.distributed_function = _origin_distributed_function
+        compiler.distributed_function = _origin_distributed_function
 
     with _patch_distributed_function():
         if model_meta.is_multimodal:
             from unsloth import FastVisionModel as UnslothModel
+        elif model_info.is_moe_model:
+            from unsloth import FastModel as UnslothModel
         else:
             from unsloth import FastLanguageModel as UnslothModel
 
-    model, processor = UnslothModel.from_pretrained(
-        model_name=args.adapters and args.adapters[0] or args.model_dir,
-        dtype=args.torch_dtype,
-        max_seq_length=args.max_length,
-        full_finetuning=args.train_type == 'full',
-        load_in_4bit=args.quant_bits == 4,
-        load_in_8bit=args.quant_bits == 8,
-        device_map=args.device_map,
-    )
+        model, processor = UnslothModel.from_pretrained(
+            model_name=args.adapters and args.adapters[0] or args.model_dir,
+            dtype=args.torch_dtype,
+            max_seq_length=args.max_length,
+            full_finetuning=args.train_type == 'full',
+            load_in_4bit=args.quant_bits == 4,
+            load_in_8bit=args.quant_bits == 8,
+            device_map=args.device_map,
+        )
     if isinstance(model, PeftModel):
         base_model = model.model
     else:
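
Three changes here: `unsloth_zoo.compiler` evidently keeps its own reference to `distributed_function`, so overriding only `utils.distributed_function` left the compiler module calling the original; both references are now swapped and restored. MoE models get their own branch via `FastModel`, and `from_pretrained` moves inside the patch context so the override is active while the model loads. The swap-and-restore pattern, written generically (`patch_attr`, `module`, and `name` are illustrative, not part of the commit):

from contextlib import contextmanager

@contextmanager
def patch_attr(module, name, replacement):
    """Temporarily replace module.<name>, restoring it on exit."""
    original = getattr(module, name)
    setattr(module, name, replacement)
    try:
        yield
    finally:
        # try/finally restores the attribute even if the body raises; the
        # commit's version restores unconditionally after the yield instead.
        setattr(module, name, original)

This is the standard fix when another module binds the function into its own namespace (`from .utils import distributed_function` style): each importing module holds an independent reference, and each must be patched separately.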

swift/trainers/trainers.py

Lines changed: 4 additions & 2 deletions
@@ -333,7 +333,7 @@ def _prepare_inputs(self, inputs):
         if logits_to_keep is not None:
             inputs['logits_to_keep'] = logits_to_keep
             if self.args.tuner_backend == 'unsloth':
-                inputs['logits_to_keep'] = logits_to_keep.sum()
+                inputs['logits_to_keep'] = int(logits_to_keep.sum())
 
         inputs['compute_loss_func'] = compute_loss_func
         inputs['loss_kwargs'] = loss_kwargs
@@ -387,8 +387,10 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         if getattr(self.args, 'average_tokens_across_devices', False) and self.model_accepts_loss_kwargs:
             loss *= self.accelerator.num_processes
 
-        if outputs.logits is not None and labels is not None and not return_outputs:
+        if (outputs.logits is not None and labels is not None and not return_outputs
+                and self.args.tuner_backend != 'unsloth'):
             # Liger does not have logits
+            # Unsloth has a bug with output logits
             self._compute_acc(outputs, labels)
         return (loss, outputs) if return_outputs else loss
 
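The first hunk casts the sum to a Python int, presumably because unsloth's code path expects `logits_to_keep` as a plain integer while `Tensor.sum()` returns a 0-dim tensor. A runnable illustration (treating `logits_to_keep` as a boolean mask over positions, which is an assumption for the example):

import torch

logits_to_keep = torch.tensor([False, False, True, True])

as_tensor = logits_to_keep.sum()    # tensor(2): a 0-dim torch.Tensor
as_int = int(logits_to_keep.sum())  # 2: a plain Python int
assert isinstance(as_int, int)

The second hunk skips `_compute_acc` entirely for the unsloth backend, since, per the new in-diff comment, unsloth's returned logits cannot be trusted for accuracy computation.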

swift/tuners/base.py

Lines changed: 1 addition & 1 deletion
@@ -526,7 +526,7 @@ def mark_trainable_callback(model, cfg):
 
     def save_pretrained(self,
                         save_directory: str,
-                        safe_serialization: bool = False,
+                        safe_serialization: bool = True,
                         adapter_name: Union[str, List[str]] = None,
                         **kwargs):
         """Save the adapters to a local directory.
