| 
 | 1 | +import os  | 
 | 2 | +import sys  | 
 | 3 | + | 
 | 4 | +import requests  | 
 | 5 | +from modelscope import snapshot_download  | 
 | 6 | +from qwen_omni_utils import process_mm_info  | 
 | 7 | +from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor  | 
 | 8 | + | 
 | 9 | +from swift.llm import InferRequest, PtEngine, RequestConfig  | 
 | 10 | + | 
 | 11 | +sys.path.append('examples/custom/my_qwen2_5_omni')  | 
 | 12 | + | 
 | 13 | + | 
 | 14 | +def infer_hf():  | 
 | 15 | +    model_dir = snapshot_download('Qwen/Qwen2.5-Omni-7B')  | 
 | 16 | +    model = Qwen2_5OmniForConditionalGeneration.from_pretrained(  | 
 | 17 | +        model_dir, torch_dtype='auto', device_map='auto', attn_implementation='flash_attention_2')  | 
 | 18 | +    processor = Qwen2_5OmniProcessor.from_pretrained(model_dir)  | 
 | 19 | +    # Use decord to read video (url not yet supported)  | 
 | 20 | +    resp = requests.get('https://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/baby.mp4')  | 
 | 21 | +    with open('_baby.mp4', 'wb') as f:  | 
 | 22 | +        f.write(resp.content)  | 
 | 23 | + | 
 | 24 | +    conversation = [  | 
 | 25 | +        {  | 
 | 26 | +            'role':  | 
 | 27 | +            'user',  | 
 | 28 | +            'content': [  | 
 | 29 | +                {  | 
 | 30 | +                    'type': 'video',  | 
 | 31 | +                    'video': '_baby.mp4'  | 
 | 32 | +                },  | 
 | 33 | +                {  | 
 | 34 | +                    'type': 'image',  | 
 | 35 | +                    'image': 'http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png'  | 
 | 36 | +                },  | 
 | 37 | +                {  | 
 | 38 | +                    'type': 'text',  | 
 | 39 | +                    'text': 'Describe the video and image.'  | 
 | 40 | +                },  | 
 | 41 | +            ],  | 
 | 42 | +        },  | 
 | 43 | +    ]  | 
 | 44 | + | 
 | 45 | +    USE_AUDIO_IN_VIDEO = False  | 
 | 46 | +    text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)  | 
 | 47 | +    audios, images, videos = process_mm_info(conversation, use_audio_in_video=USE_AUDIO_IN_VIDEO)  | 
 | 48 | +    inputs = processor(  | 
 | 49 | +        text=text,  | 
 | 50 | +        audio=audios,  | 
 | 51 | +        images=images,  | 
 | 52 | +        videos=videos,  | 
 | 53 | +        return_tensors='pt',  | 
 | 54 | +        padding=True,  | 
 | 55 | +        use_audio_in_video=USE_AUDIO_IN_VIDEO)  | 
 | 56 | +    inputs = inputs.to(model.device).to(model.dtype)  | 
 | 57 | +    text_ids = model.generate(  | 
 | 58 | +        **inputs, use_audio_in_video=USE_AUDIO_IN_VIDEO, thinker_do_sample=False, return_audio=False)  | 
 | 59 | +    text = processor.batch_decode(  | 
 | 60 | +        text_ids[:, inputs['input_ids'].shape[1]:], skip_special_tokens=True, clean_up_tokenization_spaces=False)  | 
 | 61 | +    return inputs['input_ids'][0].tolist(), text[0]  | 
 | 62 | + | 
 | 63 | + | 
 | 64 | +def test_my_qwen2_5_omni():  | 
 | 65 | +    engine = PtEngine('Qwen/Qwen2.5-Omni-7B', model_type='my_qwen2_5_omni', attn_impl='flash_attention_2')  | 
 | 66 | +    infer_request = InferRequest(  | 
 | 67 | +        messages=[{  | 
 | 68 | +            'role': 'user',  | 
 | 69 | +            'content': '<video><image>Describe the video and image.',  | 
 | 70 | +        }],  | 
 | 71 | +        videos=['https://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/baby.mp4'],  | 
 | 72 | +        images=['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png'],  | 
 | 73 | +    )  | 
 | 74 | +    request_config = RequestConfig(temperature=0, max_tokens=512)  | 
 | 75 | +    input_ids = engine.default_template.encode(infer_request)['input_ids']  | 
 | 76 | +    resp_list = engine.infer([infer_request], request_config)  | 
 | 77 | +    resp = resp_list[0].choices[0].message.content  | 
 | 78 | +    return input_ids, resp  | 
 | 79 | + | 
 | 80 | + | 
 | 81 | +if __name__ == '__main__':  | 
 | 82 | +    import my_register  | 
 | 83 | +    # Enable debug mode, will print input_ids and generate_ids from `PtEngine.infer`  | 
 | 84 | +    os.environ['SWIFT_DEBUG'] = '1'  | 
 | 85 | +    input_ids_hf, response_hf = infer_hf()  | 
 | 86 | +    input_ids_swift, response_swift = test_my_qwen2_5_omni()  | 
 | 87 | +    # Test input_ids and response alignment  | 
 | 88 | +    assert input_ids_hf == input_ids_swift  | 
 | 89 | +    assert response_hf == response_swift  | 
0 commit comments