|
18 | 18 |
|
19 | 19 | import pytest |
20 | 20 | from packaging.version import Version |
21 | | -from transformers import AutoModelForCausalLM, AutoTokenizer |
| 21 | +from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer |
22 | 22 | from transformers.testing_utils import torch_device |
23 | 23 |
|
24 | 24 | from trl.generation.vllm_client import VLLMClient |
|
31 | 31 | kill_process, |
32 | 32 | require_3_accelerators, |
33 | 33 | require_torch_multi_accelerator, |
| 34 | + require_vision, |
34 | 35 | require_vllm, |
35 | 36 | ) |
36 | 37 |
|
@@ -874,3 +875,98 @@ def teardown_class(cls): |
874 | 875 | # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to |
875 | 876 | # kill the server process and its children explicitly. |
876 | 877 | kill_process(cls.server_process) |
| 878 | + |
| 879 | + |
@pytest.mark.slow
@require_vllm
@require_vision
class TestVLLMClientServerVLM(TrlTestCase):
    """Generation-only integration tests for `VLLMClient` against a vision-language model.

    A `trl vllm-serve` subprocess hosts the model for the whole class. Only `generate`
    is exercised, so no weight-sync communicator is initialized.
    """

    model_id = "Qwen/Qwen2.5-VL-3B-Instruct"

    @classmethod
    def setup_class(cls):
        # Start the server process; capture stdout/stderr so server chatter does not
        # pollute the pytest output.
        cls.server_process = subprocess.Popen(
            ["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )

        # Initialize the client (no communicator needed for generation-only tests).
        # The long timeout allows for model download and server warm-up on a cold cache.
        cls.client = VLLMClient(connection_timeout=240, host="localhost")

        # Loading the processor is slow and it is used statelessly by the tests below,
        # so share a single instance across the class.
        cls.processor = AutoProcessor.from_pretrained(cls.model_id)

    def test_generate_with_token_ids_and_image(self):
        """Test a batch where each prompt references one or more images."""
        from PIL import Image

        image1 = Image.new("RGB", (64, 64), color="red")
        image2 = Image.new("RGB", (64, 64), color="blue")
        image3 = Image.new("RGB", (64, 64), color="green")
        messages = [
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image1},
                        {"type": "image", "image": image2},
                        {"type": "text", "text": "What are the differences between these two images?"},
                    ],
                }
            ],
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image3},
                        {"type": "text", "text": "What is the color of this image?"},
                    ],
                }
            ],
        ]
        prompt_token_ids = self.processor.apply_chat_template(
            conversation=messages, tokenize=True, add_generation_prompt=True
        )
        # `images` is aligned with the prompts: one list of images per prompt.
        outputs = self.client.generate(prompt_token_ids, images=[[image1, image2], [image3]], max_tokens=64)
        prompt_ids = outputs["prompt_ids"]
        completion_ids = outputs["completion_ids"]

        assert len(prompt_ids) == 2
        assert len(completion_ids) == 2
        assert all(isinstance(tok, int) for tok in prompt_ids[0])
        assert all(isinstance(tok, int) for tok in completion_ids[0])

    def test_generate_with_token_ids_mixed_images(self):
        """Test a batch where one prompt has an image and the other does not."""
        from PIL import Image

        image = Image.new("RGB", (64, 64), color="red")
        messages = [
            [
                {
                    "role": "user",
                    "content": [{"type": "image", "image": image}, {"type": "text", "text": "Describe this image."}],
                }
            ],
            [
                {
                    "role": "user",
                    "content": [{"type": "text", "text": "What is 1+1?"}],
                }
            ],
        ]
        prompt_token_ids = self.processor.apply_chat_template(
            conversation=messages, tokenize=True, add_generation_prompt=True
        )
        # `None` marks a text-only prompt in an otherwise multimodal batch.
        outputs = self.client.generate(prompt_token_ids, images=[[image], None], max_tokens=64)
        prompt_ids = outputs["prompt_ids"]
        completion_ids = outputs["completion_ids"]

        assert len(prompt_ids) == 2
        assert len(completion_ids) == 2
        assert all(isinstance(tok, int) for tok in prompt_ids[0])
        assert all(isinstance(tok, int) for tok in prompt_ids[1])
        assert all(isinstance(tok, int) for tok in completion_ids[0])
        assert all(isinstance(tok, int) for tok in completion_ids[1])

    @classmethod
    def teardown_class(cls):
        # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie
        # processes, we need to kill the server process and its children explicitly.
        kill_process(cls.server_process)
0 commit comments