Skip to content

Commit da4baf5

Browse files
committed
Fix qwen-audio inference bug (#204)
1 parent 9873fdc commit da4baf5

File tree

6 files changed

+66
-47
lines changed

6 files changed

+66
-47
lines changed

docs/source/LLM/LLM推理文档.md

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -274,24 +274,24 @@ template = get_template(template_type, tokenizer)
274274

275275
seed_everything(42)
276276
query = tokenizer.from_list_format([
277-
{'audio': 'demo.wav'},
278-
{'text': '请将语音转成文本'},
277+
{'audio': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Audio/1272-128104-0000.flac'},
278+
{'text': 'what does the person say?'},
279279
])
280280
response, history = inference(model, template, query)
281281
print(f'query: {query}')
282282
print(f'response: {response}')
283-
query = '这句话一般在什么语境下使用'
283+
query = 'Find the start time and end time of the word "middle classes'
284284
response, history = inference(model, template, query, history)
285285
print(f'query: {query}')
286286
print(f'response: {response}')
287287
print(f'history: {history}')
288-
"""
289-
query: Audio 1:<audio>demo.wav</audio>
290-
请将语音转成文本
291-
response: 好的,这是转成的文本:"每一天都要快乐哦"。
292-
query: 这句话一般在什么语境下使用
293-
response: 这句话一般在表达祝福或者鼓励的时候使用,比如在朋友或者亲人过生日的时候说"每一天都要快乐哦",表达祝福的意思。
294-
history: [('Audio 1:<audio>demo.wav</audio>\n请将语音转成文本', '好的,这是转成的文本:"每一天都要快乐哦"。'), ('这句话一般在什么语境下使用', '这句话一般在表达祝福或者鼓励的时候使用,比如在朋友或者亲人过生日的时候说"每一天都要快乐哦",表达祝福的意思。')]
288+
"""Out[0]
289+
query: Audio 1:<audio>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Audio/1272-128104-0000.flac</audio>
290+
what does the person say?
291+
response: The person says: "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel".
292+
query: Find the start time and end time of the word "middle classes
293+
response: The word "middle classes" starts at <|2.33|> seconds and ends at <|3.26|> seconds.
294+
history: [('Audio 1:<audio>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Audio/1272-128104-0000.flac</audio>\nwhat does the person say?', 'The person says: "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel".'), ('Find the start time and end time of the word "middle classes', 'The word "middle classes" starts at <|2.33|> seconds and ends at <|3.26|> seconds.')]
295295
"""
296296
```
297297

docs/source/LLM/命令行参数.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,14 @@
8686
- `--template_type`: 默认值为`'AUTO'`, 具体的参数介绍可以在`sft.sh命令行参数`中查看.
8787
- `--ckpt_dir`: 必填项, 值为SFT阶段保存的checkpoint路径, e.g. `'/path/to/your/vx_xxx/checkpoint-xxx'`.
8888
- `--load_args_from_ckpt_dir`: 是否从`ckpt_dir`中的`sft_args.json`文件中读取配置信息. 默认是`True`.
89-
- `--eval_human`: 使用数据集中的验证集部分进行评估还是使用人工的方式评估, 默认值为`False`.
89+
- `--load_dataset_config`: 该参数只有在`--load_args_from_ckpt_dir true`时才生效. 即是否从`ckpt_dir`中的`sft_args.json`文件中读取数据集相关的配置信息. 默认为`True`.
90+
- `--eval_human`: 使用数据集中的验证集部分进行评估还是使用人工的方式评估. 默认值为`None`, 如有传入数据集, 则设置为True, 否则设置为False.
9091
- `--seed`: 默认值为`42`, 具体的参数介绍可以在`sft.sh命令行参数`中查看.
9192
- `--dtype`: 默认值为`'AUTO'`, 具体的参数介绍可以在`sft.sh命令行参数`中查看.
92-
- `--dataset`: 默认值为`'blossom-math-zh'`, 具体的参数介绍可以在`sft.sh命令行参数`中查看. 该参数只有在`eval_human`设置为False时才生效.
93-
- `--dataset_seed`: 默认值为`42`, 具体的参数介绍可以在`sft.sh命令行参数`中查看. 该参数只有在`eval_human`设置为False时才生效.
94-
- `--dataset_test_ratio`: 默认值为`0.01`, 具体的参数介绍可以在`sft.sh命令行参数`中查看. 该参数只有在`eval_human`设置为False时才生效.
95-
- `--val_dataset_sample`: 表示想要评估和展示的验证集的数量, 默认值为`10`. 该参数只有在`eval_human`设置为False时才生效.
93+
- `--dataset`: 默认值为`'blossom-math-zh'`, 具体的参数介绍可以在`sft.sh命令行参数`中查看. 该参数在`eval_human`设置为True时不生效.
94+
- `--dataset_seed`: 默认值为`42`, 具体的参数介绍可以在`sft.sh命令行参数`中查看. 该参数在`eval_human`设置为True时不生效.
95+
- `--dataset_test_ratio`: 默认值为`0.01`, 具体的参数介绍可以在`sft.sh命令行参数`中查看. 该参数在`eval_human`设置为True时不生效.
96+
- `--val_dataset_sample`: 表示想要评估和展示的验证集的数量, 默认值为`10`. 该参数在`eval_human`设置为True时不生效.
9697
- `--system`: 默认值为`None`. 具体的参数介绍可以在`sft.sh命令行参数`中查看.
9798
- `--max_length`: 默认值为`2048`. 具体的参数介绍可以在`sft.sh命令行参数`中查看.
9899
- `--truncation_strategy`: 默认是`'delete'`. 具体的参数介绍可以在`sft.sh命令行参数`中查看.

swift/llm/infer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def llm_infer(args: InferArguments) -> None:
161161
jsonl_path = os.path.join(args.ckpt_dir, f'infer_result_{time}.jsonl')
162162
if args.eval_human:
163163
input_mode: Literal['S', 'M'] = 'S'
164-
logger.info('Input `exit` to exit the conversation.')
164+
logger.info('Input `exit` or `quit` to exit the conversation.')
165165
logger.info('Input `multi-line` to switch to multi-line input mode.')
166166
if template.support_multi_round:
167167
logger.info('Input `clear` to clear the history.')
@@ -174,7 +174,7 @@ def llm_infer(args: InferArguments) -> None:
174174
query = input('<<< ')
175175
else:
176176
query = read_multi_line()
177-
if query.strip().lower() == 'exit':
177+
if query.strip().lower() in {'exit', 'quit'}:
178178
break
179179
elif query.strip().lower() == 'clear':
180180
history = []
@@ -186,7 +186,7 @@ def llm_infer(args: InferArguments) -> None:
186186
'Input `single-line` to switch to single-line input mode.')
187187
continue
188188
if input_mode == 'M' and query.strip().lower() == 'single-line':
189-
input_mode == 'S'
189+
input_mode = 'S'
190190
continue
191191
if not template.support_multi_round:
192192
history = []

swift/llm/utils/argument.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def __post_init__(self) -> None:
267267
if self.logging_dir is None:
268268
self.logging_dir = f'{self.output_dir}/runs'
269269
if self.report_to is None:
270-
self.report_to == ['all']
270+
self.report_to = ['all']
271271
if self.gradient_accumulation_steps is None:
272272
self.gradient_accumulation_steps = math.ceil(16 / self.batch_size
273273
/ world_size)
@@ -296,7 +296,7 @@ class InferArguments:
296296
default=None, metadata={'help': '/path/to/your/vx_xxx/checkpoint-xxx'})
297297
load_args_from_ckpt_dir: bool = True
298298
load_dataset_config: bool = True
299-
eval_human: bool = False # False: eval val_dataset
299+
eval_human: Optional[bool] = None # False: eval val_dataset
300300

301301
seed: int = 42
302302
dtype: str = field(
@@ -363,17 +363,22 @@ def __post_init__(self) -> None:
363363
if self.template_type == 'AUTO':
364364
self.template_type = get_default_template_type(self.model_type)
365365
logger.info(f'Setting template_type: {self.template_type}')
366-
if not self.eval_human:
367-
if isinstance(self.dataset, str):
368-
self.dataset = [self.dataset]
369-
elif self.dataset is None:
370-
self.dataset = []
371-
if len(self.dataset) == 0:
372-
if (len(self.custom_train_dataset_path) == 0
373-
and len(self.custom_val_dataset_path) == 0):
374-
raise ValueError(
375-
f'self.dataset: {self.dataset}. Please set `--eval_human true` or `--dataset xxx`'
376-
)
366+
if isinstance(self.dataset, str):
367+
self.dataset = [self.dataset]
368+
elif self.dataset is None:
369+
self.dataset = []
370+
if (len(self.dataset) == 0 and len(self.custom_train_dataset_path) == 0
371+
and len(self.custom_val_dataset_path) == 0):
372+
if self.eval_human is None:
373+
self.eval_human = True
374+
logger.info(f'Setting self.eval_human: {self.eval_human}')
375+
if not self.eval_human:
376+
raise ValueError(
377+
f'self.dataset: {self.dataset}. Please set `--eval_human true` or `--dataset xxx`'
378+
)
379+
elif self.eval_human is None:
380+
self.eval_human = False
381+
logger.info(f'Setting self.eval_human: {self.eval_human}')
377382

378383
self.bnb_4bit_compute_dtype, self.load_in_4bit, self.load_in_8bit = select_bnb(
379384
self)

swift/llm/utils/model.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,19 @@ def fix_qwen_inplace_bug(model) -> None:
801801
first_drop.__old_forward = __old_forward
802802

803803

804+
def _qwen_vl_audio_decode(self,
805+
*args,
806+
skip_special_tokens=False,
807+
**kwargs) -> str:
808+
if skip_special_tokens:
809+
token_ids = kwargs['token_ids']
810+
while len(token_ids) > 0 and token_ids[-1] in {151645, 151643}:
811+
token_ids.pop()
812+
return self._old_decode(*args, skip_special_tokens=False, **kwargs)
813+
else:
814+
return self._old_decode(*args, skip_special_tokens=False, **kwargs)
815+
816+
804817
@register_model(
805818
ModelType.qwen_vl_chat,
806819
'qwen/Qwen-VL-Chat',
@@ -838,21 +851,9 @@ def get_model_tokenizer_qwen_vl(model_dir: str,
838851
load_model, **kwargs)
839852
if model is not None:
840853
fix_qwen_inplace_bug(model)
841-
842-
_old_decode = tokenizer._decode
843-
844-
def _new_decode(*args, skip_special_tokens=False, **kwargs) -> str:
845-
if skip_special_tokens:
846-
token_ids = kwargs['token_ids']
847-
while len(token_ids) > 0 and token_ids[-1] in {151645, 151643}:
848-
token_ids.pop()
849-
return _old_decode(*args, skip_special_tokens=False, **kwargs)
850-
else:
851-
return _old_decode(*args, skip_special_tokens=False, **kwargs)
852-
853854
if not hasattr(tokenizer, '_old_decode'): # avoid double patching
854-
tokenizer._old_decode = _old_decode
855-
tokenizer._decode = _new_decode
855+
tokenizer._old_decode = tokenizer._decode
856+
tokenizer._decode = MethodType(_qwen_vl_audio_decode, tokenizer)
856857

857858
return model, tokenizer
858859

@@ -888,6 +889,10 @@ def get_model_tokenizer_qwen_audio(model_dir: str,
888889
load_model, **kwargs)
889890
if model is not None:
890891
fix_qwen_inplace_bug(model)
892+
if not hasattr(tokenizer, '_old_decode'): # avoid double patching
893+
tokenizer._old_decode = tokenizer._decode
894+
tokenizer._decode = MethodType(_qwen_vl_audio_decode, tokenizer)
895+
891896
return model, tokenizer
892897

893898

swift/llm/utils/template.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from copy import deepcopy
33
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
44

5+
import torch
56
from torch import Tensor
67
from transformers import PreTrainedTokenizerBase, StoppingCriteria
78

@@ -131,8 +132,15 @@ def _encode_context_list(
131132
elif isinstance(context, str):
132133
if (getattr(tokenizer, 'model_type', '').startswith('qwen-audio')):
133134
audio_info = get_audio_info(tokenizer, context=context)
134-
assert 'audio_info' not in kwargs
135-
kwargs['audio_info'] = audio_info
135+
old_audio_info = kwargs.get('audio_info')
136+
if old_audio_info is None:
137+
kwargs['audio_info'] = audio_info
138+
elif audio_info is not None:
139+
for k in ['input_audios', 'input_audio_lengths']:
140+
old_audio_info[k] = torch.concat(
141+
[old_audio_info[k], audio_info[k]], dim=0)
142+
for k in ['audio_span_tokens', 'audio_urls']:
143+
old_audio_info[k] = old_audio_info[k] + audio_info[k]
136144
token_list = tokenizer(
137145
context,
138146
return_attention_mask=False,

0 commit comments

Comments
 (0)