|
2 | 2 | import os |
3 | 3 | import platform |
4 | 4 | import re |
| 5 | +from contextlib import contextmanager |
5 | 6 | from copy import deepcopy |
6 | 7 | from dataclasses import asdict, dataclass, field |
7 | 8 | from functools import partial |
@@ -133,10 +134,27 @@ def load_by_unsloth(args): |
133 | 134 | os.environ['UNSLOTH_DISABLE_STATISTICS'] = '1' |
134 | 135 | model_info = args.model_info |
135 | 136 | model_meta = args.model_meta |
136 | | - if model_meta.is_multimodal: |
137 | | - from unsloth import FastVisionModel as UnslothModel |
138 | | - else: |
139 | | - from unsloth import FastLanguageModel as UnslothModel |
| 137 | + |
| 138 | + os.environ['UNSLOTH_IS_PRESENT'] = '1' |
| 139 | + |
| 140 | + @contextmanager |
| 141 | + def _patch_distributed_function(): |
| 142 | + from unsloth_zoo import utils |
| 143 | + |
| 144 | + def distributed_function(n=1, function=None, *args, **kwargs): |
| 145 | + return function(*args, **kwargs) |
| 146 | + |
| 147 | + _origin_distributed_function = utils.distributed_function |
| 148 | + utils.distributed_function = distributed_function |
| 149 | + yield |
| 150 | + utils.distributed_function = _origin_distributed_function |
| 151 | + |
| 152 | + with _patch_distributed_function(): |
| 153 | + if model_meta.is_multimodal: |
| 154 | + from unsloth import FastVisionModel as UnslothModel |
| 155 | + else: |
| 156 | + from unsloth import FastLanguageModel as UnslothModel |
| 157 | + |
140 | 158 | model, processor = UnslothModel.from_pretrained( |
141 | 159 | model_name=args.adapters and args.adapters[0] or args.model_dir, |
142 | 160 | dtype=args.torch_dtype, |
|
0 commit comments