Commit eb03170
[bugfix] fix keye_vl (#5848)
1 parent 5ed5913 commit eb03170

File tree

5 files changed: +70 −23 lines

examples/models/keye/train.sh

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+# 24GiB
+CUDA_VISIBLE_DEVICES=0 \
+swift sft \
+    --model Kwai-Keye/Keye-VL-1_5-8B \
+    --dataset 'AI-ModelScope/LaTeX_OCR:human_handwrite#20000' \
+    --split_dataset_ratio 0.01 \
+    --train_type lora \
+    --torch_dtype bfloat16 \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 1 \
+    --per_device_eval_batch_size 1 \
+    --learning_rate 1e-4 \
+    --lora_rank 8 \
+    --lora_alpha 32 \
+    --target_modules all-linear \
+    --freeze_vit true \
+    --gradient_accumulation_steps 16 \
+    --eval_steps 50 \
+    --save_steps 50 \
+    --save_total_limit 2 \
+    --logging_steps 5 \
+    --max_length 2048 \
+    --output_dir output \
+    --warmup_ratio 0.05 \
+    --dataset_num_proc 4 \
+    --dataloader_num_workers 4

swift/llm/model/model/qwen.py

Lines changed: 2 additions & 0 deletions

@@ -685,6 +685,8 @@ def _new_read_video_decord(ele: dict):
     backends = getattr(vision_process, 'VIDEO_READER_BACKENDS', None)
     if isinstance(backends, dict):
         backends['decord'] = _new_read_video_decord
+    elif backends is None:  # keye_vl
+        vision_process._read_video_decord = _new_read_video_decord
     vision_process._patch = True
     return res

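The two added lines handle keye_vl_utils, whose bundled vision_process module does not define a VIDEO_READER_BACKENDS registry: getattr returns None, so the replacement reader has to be assigned straight onto the module instead of into the registry dict that qwen_vl_utils provides. Below is a minimal sketch of that dispatch; patch_video_reader and the stub reader are illustrative names, not swift or keye_vl_utils internals.

import types


def _new_read_video_decord(ele: dict):
    """Stand-in for swift's replacement decord video reader."""
    raise NotImplementedError


def patch_video_reader(vision_process: types.ModuleType) -> None:
    backends = getattr(vision_process, 'VIDEO_READER_BACKENDS', None)
    if isinstance(backends, dict):
        # qwen_vl_utils path: the module keeps a backend registry, so
        # swapping the 'decord' entry reroutes every lookup through it.
        backends['decord'] = _new_read_video_decord
    elif backends is None:
        # keye_vl path: no registry exists, so overwrite the module-level
        # function itself; this works because callers resolve
        # vision_process._read_video_decord at call time.
        vision_process._read_video_decord = _new_read_video_decord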

swift/llm/template/template/kwai.py

Lines changed: 24 additions & 16 deletions

@@ -41,11 +41,12 @@ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int
             video = inputs.videos[index]
             if os.path.isdir(video):
                 video = [os.path.join(video, fname) for fname in os.listdir(video)]
-            video, video_kwargs = fetch_video({'video': video}, return_video_sample_fps=True)
+            video, video_kwargs = fetch_video({'video': video})
             if isinstance(video, torch.Tensor):
                 video = video.to(torch.uint8)
             inputs.videos[index] = video
-            inputs.mm_processor_kwargs.setdefault('fps', []).append(video_kwargs)
+            for k, v in video_kwargs.items():
+                inputs.mm_processor_kwargs.setdefault(k, []).append(v)
             return ['<|vision_start|><|video_pad|><|vision_end|>']

     def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
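This hunk drops the return_video_sample_fps=True argument and stops hard-coding the 'fps' key: whatever kwargs fetch_video reports for a video are now appended key-wise into mm_processor_kwargs, one list entry per video. A small sketch of that accumulation, with made-up fps values:

# Per-video kwargs accumulation; the keys and values are illustrative.
mm_processor_kwargs = {}
for video_kwargs in [{'fps': 2.0}, {'fps': 1.0}]:  # one dict per <video>
    for k, v in video_kwargs.items():
        mm_processor_kwargs.setdefault(k, []).append(v)
assert mm_processor_kwargs == {'fps': [2.0, 1.0]}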
@@ -62,25 +63,24 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
                 media_inputs = processor.image_processor(images=mm_data, return_tensors='pt', do_resize=False)
                 media_grid_thw = media_inputs['image_grid_thw']
             else:
-                kwargs = {}
-                if hasattr(processor, 'video_processor'):
-                    processor_func = processor.video_processor
-                else:
-                    processor_func = processor.image_processor
-                    kwargs['images'] = None
-                media_inputs = processor_func(videos=mm_data, return_tensors='pt', do_resize=False, **kwargs)
+                split_token = self._tokenize('\n')[0]
+                media_inputs = processor(
+                    text=['\n'.join(['<|video_pad|>'] * len(mm_data))],
+                    videos=mm_data,
+                    return_tensors='pt',
+                    **inputs.mm_processor_kwargs)
+                splited_tokens = self._split_list(media_inputs['input_ids'][0].tolist(), split_token)
                 media_grid_thw = media_inputs['video_grid_thw']
                 media_token = self.video_token_id
-                fps = inputs.mm_processor_kwargs['fps']
-                media_inputs['second_per_grid_ts'] = [
-                    processor.image_processor.temporal_patch_size / tmp for tmp in fps
-                ]
             idx_list = findall(input_ids, media_token)
             merge_length = processor.image_processor.merge_size**2

             def _get_new_tokens(i):
-                token_len = (media_grid_thw[i].prod() // merge_length)
-                return [media_token] * token_len
+                if media_type == 'images':
+                    token_len = (media_grid_thw[i].prod() // merge_length)
+                    return [media_token] * token_len
+                else:
+                    return splited_tokens[i]

             input_ids, labels, loss_scale = self._extend_tokens(input_ids, labels, loss_scale, idx_list,
                                                                 _get_new_tokens)
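The rewritten video branch also removes the manual second_per_grid_ts computation. Instead, it joins one <|video_pad|> placeholder per video with '\n', lets the full processor expand each placeholder into that video's token run, and then splits the expanded input_ids on the newline token so _get_new_tokens can return each run verbatim. A runnable sketch of the split step, assuming _split_list behaves like the helper below (an illustrative re-implementation, not swift's own):

from typing import List


def split_list(tokens: List[int], sep: int) -> List[List[int]]:
    """Split a flat token list on a separator id, dropping the separator."""
    out, cur = [], []
    for tok in tokens:
        if tok == sep:
            out.append(cur)
            cur = []
        else:
            cur.append(tok)
    out.append(cur)
    return out


# Two videos: each <|video_pad|> expands into that video's token run, and
# the '\n' token id (here 0) marks the boundary between the two runs.
assert split_list([7, 7, 7, 0, 7, 7], sep=0) == [[7, 7, 7], [7, 7]]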
@@ -291,6 +291,14 @@ def _data_collator_mm_data(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:

 # Register the Keye VL template
 register_template(KeyeTemplateMeta(MLLMTemplateType.keye_vl, template_cls=KeyeVLTemplate))
+
+
+class KeyeVL1_5Template(KeyeVLTemplate):
+
+    def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        return super(KeyeVLTemplate, self)._post_encode(model, inputs)
+
+
 register_template(
     KeyeTemplateMeta(
-        MLLMTemplateType.keye_vl_1_5, template_cls=KeyeVLTemplate, default_system='You are a helpful assistant.'))
+        MLLMTemplateType.keye_vl_1_5, template_cls=KeyeVL1_5Template, default_system='You are a helpful assistant.'))
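KeyeVL1_5Template exists only to bypass KeyeVLTemplate's _post_encode: the two-argument form super(KeyeVLTemplate, self) starts the method lookup after KeyeVLTemplate in the MRO, so the call lands on the grandparent implementation. A minimal sketch of that dispatch with stand-in classes:

# Stand-in class hierarchy; only the super(...) call mirrors the real code.
class Base:
    def _post_encode(self) -> str:
        return 'base'


class KeyeVL(Base):
    def _post_encode(self) -> str:
        return 'keye_vl'  # the behavior the 1.5 template needs to skip


class KeyeVL1_5(KeyeVL):
    def _post_encode(self) -> str:
        # super(KeyeVL, self) resumes the MRO search *after* KeyeVL,
        # so this resolves to Base._post_encode, not KeyeVL's override.
        return super(KeyeVL, self)._post_encode()


assert KeyeVL1_5()._post_encode() == 'base'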

tests/test_align/test_template/test_video.py

Lines changed: 15 additions & 4 deletions

@@ -173,15 +173,25 @@ def test_glm4_5v():


 def test_keye_vl():
-    pt_engine = PtEngine('Kwai-Keye/Keye-VL-8B-Preview', attn_impl='flash_attention_2')
-    messages = [{'role': 'user', 'content': '<video>What happened in the video?'}]
+    pt_engine = PtEngine('Kwai-Keye/Keye-VL-8B-Preview')
+    messages = [{'role': 'user', 'content': '<video>Describe this video.'}]
     videos = ['https://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/baby.mp4']
     response = _infer_model(pt_engine, messages=messages, videos=videos)
     pt_engine.default_template.template_backend = 'jinja'
     response2 = _infer_model(pt_engine, messages=messages, videos=videos)
     assert response == response2


+def test_keye_vl_1_5():
+    pt_engine = PtEngine('Kwai-Keye/Keye-VL-1_5-8B')
+    messages = [{'role': 'user', 'content': '<video>Describe this video.'}]
+    videos = ['https://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/baby.mp4']
+    response = _infer_model(pt_engine, messages=messages, videos=videos)
+    assert response[:200] == ('The video features a young child sitting on a bed, engrossed in '
+                              'reading a book. The child is wearing a light blue sleeveless top and pink '
+                              'pants. The book appears to be a hardcover with illustrations, ')
+
+
 def test_ovis2_5():
     pt_engine = PtEngine('AIDC-AI/Ovis2.5-2B')
     messages = [{'role': 'user', 'content': '<video>Describe this video in detail.'}]

@@ -241,9 +251,10 @@ def test_minicpmv4_5():
     # test_qwen2_5_vl()
     # test_qwen2_5_omni()
     # test_glm4_1v()  # bug now, wait model fix
-    # test_keye_vl()
+    test_keye_vl()
+    test_keye_vl_1_5()
     # test_glm4_5v()
     # test_ovis2_5()
     # test_interns1()
     # test_internvl3_5()
-    test_minicpmv4_5()
+    # test_minicpmv4_5()

tests/test_align/test_template/test_vision.py

Lines changed: 3 additions & 3 deletions

@@ -790,13 +790,13 @@ def test_minicpmv4_5():
     # test_kimi_vl_thinking()
     # test_glm4_1v()
     # test_gemma3n()
-    # test_keye_vl()
+    test_keye_vl()
     # test_dots_ocr()
     # test_glm4_5v()
     # test_interns1()
     # test_internvl3_5()
     # test_minicpmv4_5()
-    # test_keye_vl_1_5()
+    test_keye_vl_1_5()
     # test_internvl3_hf()
     # test_internvl3_5_hf()
-    test_internvl_gpt_hf()
+    # test_internvl_gpt_hf()

0 commit comments