Commit 294713e

tastelikefeet authored and committed
Support ddp of unsloth (#5141)
Co-authored-by: tastelikefeet <[email protected]>
1 parent fdd6c2a commit 294713e

File tree

3 files changed: +25 -4 lines

swift/llm/model/register.py

Lines changed: 22 additions & 4 deletions
@@ -2,6 +2,7 @@
 import os
 import platform
 import re
+from contextlib import contextmanager
 from copy import deepcopy
 from dataclasses import asdict, dataclass, field
 from functools import partial
@@ -133,10 +134,27 @@ def load_by_unsloth(args):
     os.environ['UNSLOTH_DISABLE_STATISTICS'] = '1'
     model_info = args.model_info
     model_meta = args.model_meta
-    if model_meta.is_multimodal:
-        from unsloth import FastVisionModel as UnslothModel
-    else:
-        from unsloth import FastLanguageModel as UnslothModel
+
+    os.environ['UNSLOTH_IS_PRESENT'] = '1'
+
+    @contextmanager
+    def _patch_distributed_function():
+        from unsloth_zoo import utils
+
+        def distributed_function(n=1, function=None, *args, **kwargs):
+            return function(*args, **kwargs)
+
+        _origin_distributed_function = utils.distributed_function
+        utils.distributed_function = distributed_function
+        yield
+        utils.distributed_function = _origin_distributed_function
+
+    with _patch_distributed_function():
+        if model_meta.is_multimodal:
+            from unsloth import FastVisionModel as UnslothModel
+        else:
+            from unsloth import FastLanguageModel as UnslothModel
+
 model, processor = UnslothModel.from_pretrained(
     model_name=args.adapters and args.adapters[0] or args.model_dir,
     dtype=args.torch_dtype,
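
The change works by swapping `unsloth_zoo.utils.distributed_function` for a passthrough while unsloth is imported, so under DDP each rank runs the wrapped function directly instead of deferring to unsloth's own distributed coordination. The patch-and-restore context manager is a general pattern; below is a minimal sketch assuming only the standard library (the helper name `patched_attr` is hypothetical, not part of swift or unsloth):

    from contextlib import contextmanager

    @contextmanager
    def patched_attr(obj, name, replacement):
        # Temporarily replace obj.<name> with `replacement`; restore on exit.
        original = getattr(obj, name)
        setattr(obj, name, replacement)
        try:
            yield
        finally:
            # try/finally restores the attribute even if the body raises.
            setattr(obj, name, original)

    # Usage sketch mirroring the diff above:
    #   from unsloth_zoo import utils
    #   with patched_attr(utils, 'distributed_function',
    #                     lambda n=1, function=None, *args, **kwargs: function(*args, **kwargs)):
    #       from unsloth import FastLanguageModel

One difference worth noting: the diff restores the original after a bare `yield`, so an exception raised during the unsloth import would leave the patch in place; wrapping the `yield` in `try/finally`, as above, avoids that.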

swift/trainers/arguments.py

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@ class TrainArgumentsMixin:
     per_device_train_batch_size: int = 1
     per_device_eval_batch_size: int = 1
     gradient_accumulation_steps: Optional[int] = None
+    tuner_backend: Optional[str] = None
 
     gradient_checkpointing: bool = True
     vit_gradient_checkpointing: Optional[bool] = None
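
This field threads the tuner backend choice into the training arguments so the trainer can branch on it (see trainers.py below). A minimal sketch of how the field behaves, assuming `TrainArgumentsMixin` is a dataclass as the field syntax suggests (the class below is an illustrative subset, not swift's actual definition):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class TrainArgumentsMixin:
        # Only the fields visible in the diff above.
        per_device_train_batch_size: int = 1
        per_device_eval_batch_size: int = 1
        gradient_accumulation_steps: Optional[int] = None
        tuner_backend: Optional[str] = None  # e.g. 'unsloth'; None means no special backend

    args = TrainArgumentsMixin(tuner_backend='unsloth')
    assert args.tuner_backend == 'unsloth'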

swift/trainers/trainers.py

Lines changed: 2 additions & 0 deletions
@@ -332,6 +332,8 @@ def _prepare_inputs(self, inputs):
         inputs['labels'], logits_to_keep = self.get_logits_to_keep(inputs['labels'])
         if logits_to_keep is not None:
             inputs['logits_to_keep'] = logits_to_keep
+            if self.args.tuner_backend == 'unsloth':
+                inputs['logits_to_keep'] = logits_to_keep.sum()
 
         inputs['compute_loss_func'] = compute_loss_func
         inputs['loss_kwargs'] = loss_kwargs
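
The `.sum()` reads as converting `logits_to_keep` from a boolean mask into a scalar count: in the transformers convention, an integer `logits_to_keep=N` means "compute logits only for the last N positions", and this branch passes unsloth that scalar form instead of the mask. A small illustration, assuming `logits_to_keep` is a boolean mask over trailing positions (the tensor below is made up for the example):

    import torch

    # Hypothetical keep-mask over 6 sequence positions: keep the last 3.
    logits_to_keep = torch.tensor([False, False, False, True, True, True])

    # Summing a boolean mask counts its True entries, yielding the
    # "last N positions" scalar form.
    n_keep = logits_to_keep.sum()
    print(int(n_keep))  # 3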
