
Commit 3f37045

[train] support omni seq_cls (#5329)
1 parent e6709f5 commit 3f37045

File tree

17 files changed (+51 / -35 lines)

docs/source/Customization/插件化.md

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@ class IA3(Tuner):
 
     @staticmethod
     def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
-        model_arch: ModelKeys = MODEL_ARCH_MAPPING[model.model_meta.model_arch]
+        model_arch: ModelKeys = model.model_meta.model_arch
         ia3_config = IA3Config(
             target_modules=find_all_linears(model), feedforward_modules='.*' + model_arch.mlp.split('{}.')[1] + '.*')
         return get_peft_model(model, ia3_config)

docs/source_en/Customization/Pluginization.md

Lines changed: 1 addition & 1 deletion
@@ -136,7 +136,7 @@ class IA3(Tuner):
 
     @staticmethod
     def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
-        model_arch: ModelKeys = MODEL_ARCH_MAPPING[model.model_meta.model_arch]
+        model_arch: ModelKeys = model.model_meta.model_arch
         ia3_config = IA3Config(
             target_modules=find_all_linears(model), feedforward_modules='.*' + model_arch.mlp.split('{}.')[1] + '.*')
         return get_peft_model(model, ia3_config)
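
Since register_model now hands back a resolved ModelKeys object, the plugin only has to turn the architecture's mlp key into a regex. A minimal sketch of that string manipulation, assuming an mlp key of the common 'model.layers.{}.mlp' form (the concrete value depends on the registered architecture):

    # Sketch only: 'model.layers.{}.mlp' is an assumed example of ModelKeys.mlp.
    mlp_key = 'model.layers.{}.mlp'
    feedforward_modules = '.*' + mlp_key.split('{}.')[1] + '.*'
    print(feedforward_modules)  # '.*mlp.*' -- matches every module whose name contains "mlp"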

examples/notebook/qwen2vl-ocr/ocr-sft.ipynb

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@
 "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
 "\n",
 "from swift.llm import (\n",
-"    get_model_tokenizer, load_dataset, get_template, EncodePreprocessor, get_model_arch,\n",
+"    get_model_tokenizer, load_dataset, get_template, EncodePreprocessor,\n",
 "    get_multimodal_target_regex, LazyLLMDataset\n",
 ")\n",
 "from swift.utils import get_logger, get_model_parameter_info, plot_images, seed_everything\n",

examples/train/multimodal/lora_llm_full_vit/custom_plugin.py

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@
 import safetensors.torch
 import torch
 
-from swift.llm import deep_getattr, get_model_arch, get_multimodal_target_regex
+from swift.llm import deep_getattr, get_multimodal_target_regex
 from swift.plugin import Tuner, extra_tuners
 from swift.tuners import LoraConfig, Swift
 from swift.utils import get_logger
@@ -46,14 +46,14 @@ def save_pretrained(
                     state_dict[n] = p.detach().cpu()
         model.save_pretrained(save_directory, state_dict=state_dict, safe_serialization=safe_serialization, **kwargs)
         # vit
-        model_arch = get_model_arch(model.model_meta.model_arch)
+        model_arch = model.model_meta.model_arch
         state_dict = {k: v for k, v in state_dict.items() if is_vit_param(model_arch, k)}
         safetensors.torch.save_file(
             state_dict, os.path.join(save_directory, 'vit.safetensors'), metadata={'format': 'pt'})
 
     @staticmethod
     def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
-        model_arch = get_model_arch(model.model_meta.model_arch)
+        model_arch = model.model_meta.model_arch
         target_regex = get_multimodal_target_regex(model)
         logger.info(f'target_regex: {target_regex}')
         lora_config = LoraConfig(
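
The vit checkpoint above is filtered with is_vit_param, a helper defined elsewhere in this example file. A hypothetical sketch of what such a check could look like, assuming the MultiModelKeys object exposes a vision_tower prefix list; this is an illustration, not the file's actual implementation:

    # Hypothetical sketch; `vision_tower` follows the MultiModelKeys field naming,
    # but the real is_vit_param in custom_plugin.py may differ.
    def is_vit_param_sketch(model_arch, name: str) -> bool:
        prefixes = getattr(model_arch, 'vision_tower', None) or []
        return any(name.startswith(f'{prefix}.') for prefix in prefixes)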

swift/llm/argument/base_args/quant_args.py

Lines changed: 1 addition & 2 deletions
@@ -89,10 +89,9 @@ def get_quantization_config(self):
         return quantization_config
 
     def get_modules_to_not_convert(self):
-        from swift.llm import get_model_arch
         if not hasattr(self, 'model_meta') or not hasattr(self, 'model_info'):
             return None
-        model_arch = get_model_arch(self.model_meta.model_arch)
+        model_arch = self.model_meta.model_arch
         res = []
         if self.model_info.is_moe_model:
             res += ['mlp.gate', 'mlp.shared_expert_gate']

swift/llm/argument/tuner_args.py

Lines changed: 1 addition & 2 deletions
@@ -4,7 +4,6 @@
 
 from transformers.utils import strtobool
 
-from swift.llm import get_model_arch
 from swift.utils import get_logger
 
 logger = get_logger()
@@ -204,7 +203,7 @@ def __post_init__(self):
             self.target_modules = self.target_regex
 
     def _init_multimodal_full(self):
-        model_arch = get_model_arch(self.model_meta.model_arch)
+        model_arch = self.model_meta.model_arch
         if not self.model_meta.is_multimodal or not model_arch or self.train_type != 'full':
             return
         if self.freeze_llm:

swift/llm/export/quant.py

Lines changed: 4 additions & 4 deletions
@@ -7,8 +7,8 @@
 import torch.nn as nn
 from tqdm import tqdm
 
-from swift.llm import (ExportArguments, HfConfigFactory, MaxLengthError, ProcessorMixin, deep_getattr, get_model_arch,
-                       load_dataset, prepare_model_template, save_checkpoint, to_device)
+from swift.llm import (ExportArguments, HfConfigFactory, MaxLengthError, ProcessorMixin, deep_getattr, load_dataset,
+                       prepare_model_template, save_checkpoint, to_device)
 from swift.utils import get_logger, get_model_parameter_info
 
 logger = get_logger()
@@ -160,7 +160,7 @@ def awq_model_quantize(self) -> None:
             self.tokenizer, quant_config=quant_config, n_parallel_calib_samples=args.quant_batch_size)
         quantizer.get_calib_dataset = _origin_get_calib_dataset  # recover
         if self.model.quant_config.modules_to_not_convert:
-            model_arch = get_model_arch(args.model_meta.model_arch)
+            model_arch = args.model_meta.model_arch
             lm_head_key = getattr(model_arch, 'lm_head', None) or 'lm_head'
             if lm_head_key not in self.model.quant_config.modules_to_not_convert:
                 self.model.quant_config.modules_to_not_convert.append(lm_head_key)
@@ -180,7 +180,7 @@ def _patch_gptq(self):
 
     @staticmethod
     def get_block_name_to_quantize(model: nn.Module) -> Optional[str]:
-        model_arch = get_model_arch(model.model_meta.model_arch)
+        model_arch = model.model_meta.model_arch
         prefix = ''
         if hasattr(model_arch, 'language_model'):
             assert len(model_arch.language_model) == 1, f'mllm_arch.language_model: {model_arch.language_model}'

swift/llm/model/patcher.py

Lines changed: 19 additions & 1 deletion
@@ -150,12 +150,30 @@ def _check_imports(filename) -> List[str]:
         td.check_imports = _old_check_imports
 
 
+def get_lm_head_model(model, model_meta, lm_heads):
+    llm_prefix_list = getattr(model_meta.model_arch, 'language_model', None)
+    prefix_list = []
+    if llm_prefix_list:
+        prefix_list = llm_prefix_list[0].split('.')
+
+    origin_model = model
+    current_model = model
+    for prefix in [None] + prefix_list:
+        if prefix:
+            current_model = getattr(current_model, prefix)
+        for lm_head in lm_heads:
+            if hasattr(current_model, lm_head):
+                return current_model
+
+    raise ValueError(f'Cannot find the lm_head. model: {origin_model}')
+
+
 def _patch_sequence_classification(model, model_meta):
     hidden_size = HfConfigFactory.get_config_attr(model.config, 'hidden_size')
     initializer_range = HfConfigFactory.get_config_attr(model.config, 'initializer_range')
 
     lm_heads = ['lm_head', 'output', 'embed_out', 'output_layer']
-    llm_model = get_llm_model(model, model_meta=model_meta)
+    llm_model = get_lm_head_model(model, model_meta, lm_heads)
     llm_model.num_labels = model.config.num_labels
     llm_model.score = nn.Linear(hidden_size, llm_model.num_labels, bias=False, dtype=llm_model.dtype)
     if llm_model.score.weight.device == torch.device('meta'):
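
get_lm_head_model walks from the wrapper model down the language_model prefix until a submodule owns one of the known lm-head attributes; that submodule is where _patch_sequence_classification attaches the score head, which is what lets seq_cls work for omni models whose LLM sits under a nested prefix such as 'thinker.model'. A toy walk-through using the helper added above (the nesting and attribute names are illustrative only):

    from types import SimpleNamespace

    # Toy stand-ins; 'thinker.model' mirrors an omni-style nesting but is only an example.
    inner = SimpleNamespace(lm_head=object())
    model = SimpleNamespace(thinker=SimpleNamespace(model=inner))
    model_meta = SimpleNamespace(model_arch=SimpleNamespace(language_model=['thinker.model']))

    lm_heads = ['lm_head', 'output', 'embed_out', 'output_layer']
    found = get_lm_head_model(model, model_meta, lm_heads)  # helper defined in the hunk above
    assert found is inner  # the submodule that actually owns `lm_head`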

swift/llm/model/register.py

Lines changed: 3 additions & 0 deletions
@@ -117,6 +117,7 @@ def register_model(model_meta: ModelMeta, *, exist_ok: bool = False) -> None:
         model_type: The unique ID for the model type. Models with the same model_type share
             the same architectures, template, get_function, etc.
     """
+    from .model_arch import get_model_arch
     model_type = model_meta.model_type
     if not exist_ok and model_type in MODEL_MAPPING:
         raise ValueError(f'The `{model_type}` has already been registered in the MODEL_MAPPING.')
@@ -125,6 +126,8 @@ def register_model(model_meta: ModelMeta, *, exist_ok: bool = False) -> None:
         model_meta.is_multimodal = True
     if model_type in RMModelType.__dict__:
         model_meta.is_reward = True
+    if model_meta.model_arch:
+        model_meta.model_arch = get_model_arch(model_meta.model_arch)
     MODEL_MAPPING[model_type] = model_meta
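
Resolving the arch at registration time is what lets every call site in this commit drop its get_model_arch / MODEL_ARCH_MAPPING lookup: once register_model has run, model_meta.model_arch is already a ModelKeys instance rather than a string ID. A hedged illustration ('qwen2_5_omni' is an assumed example model_type, and the printed prefix depends on the registered arch):

    from swift.llm import MODEL_MAPPING

    model_meta = MODEL_MAPPING['qwen2_5_omni']   # assumed example key
    model_arch = model_meta.model_arch           # already a ModelKeys object, no lookup needed
    print(getattr(model_arch, 'language_model', None))  # e.g. ['thinker.model']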

swift/llm/model/utils.py

Lines changed: 2 additions & 3 deletions
@@ -351,9 +351,8 @@ def git_clone_github(github_url: str,
 
 
 def get_llm_model(model: torch.nn.Module, model_meta=None):
-    from swift import SwiftModel
+    from swift.tuners import SwiftModel
     from peft import PeftModel
-    from swift.llm import get_model_arch
     from accelerate.utils import extract_model_from_parallel
     model = extract_model_from_parallel(model)
 
@@ -362,7 +361,7 @@ def get_llm_model(model: torch.nn.Module, model_meta=None):
     if model_meta is None:
         model_meta = model.model_meta
 
-    llm_prefix = getattr(get_model_arch(model_meta.model_arch), 'language_model', None)
+    llm_prefix = getattr(model_meta.model_arch, 'language_model', None)
     if llm_prefix:
         llm_model = deep_getattr(model, llm_prefix[0])
     else:
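
With the prefix read straight off the resolved ModelKeys, get_llm_model reduces to a dotted-attribute walk via deep_getattr. A minimal sketch of that lookup (a hedged stand-in; the real deep_getattr is imported from swift.llm in the code above):

    from functools import reduce
    from types import SimpleNamespace

    def dotted_getattr(obj, path: str):
        # Illustrative equivalent of the dotted lookup deep_getattr performs.
        return reduce(getattr, path.split('.'), obj)

    # With language_model == ['thinker.model'] (assumed example), get_llm_model
    # effectively returns model.thinker.model:
    model = SimpleNamespace(thinker=SimpleNamespace(model='llm-backbone'))
    assert dotted_getattr(model, 'thinker.model') == 'llm-backbone'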
