# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import asdict
from typing import NamedTuple

from PIL import Image

from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import KVTransferConfig
from vllm.multimodal.utils import encode_image_base64

MODEL_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"

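# Greedy decoding (temperature 0, top_k 1) keeps the generations
# deterministic, so the storage-path assertions below are reproducible.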
SAMPLING_PARAMS = SamplingParams(temperature=0.0, top_k=1, max_tokens=128)

TEXT_PROMPTS = [
    "What's in the image(s)? Around 30 words. What's special in 2nd image?",
    "The future of AI is",
]


class InputCase(NamedTuple):
    text: str
    img: list[Image.Image]
    expected_len: int
    info: str


def _check_path_len(path):
    """Return the current number of entries under the storage path."""
    return len(list(path.iterdir()))


def _list_path(path):
    """Return the folder names (generated hashes) under the storage path."""
    return list(path.iterdir())


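# For illustration only (the hash names here are hypothetical): after two
# unique prompts, the storage path is expected to hold one entry per hash,
# e.g.
#   tmp_path/
#   ├── 0a1b2c.../   <- KV entry for prompt 1
#   └── 7d8e9f.../   <- KV entry for prompt 2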
def run_test(tmp_path, processor, llm: LLM, question: str,
             image_urls: list[Image.Image], expected_len: int, info: str):
    """
    Run one case: process the prompt and generate output based on one set of
    inputs, then check that the number of entries in the storage path matches
    the expected length. `info` describes the purpose of the case.
    """
    print(f"***info: {info}***")
    print(
        f"**Expected storage path length after llm.generate: {expected_len}**")
    process_prompt(processor, llm, question, image_urls)

    print(f"Actual storage path length: {_check_path_len(tmp_path)}")
    print(f"Hashes under the storage path: {_list_path(tmp_path)}")

    assert _check_path_len(tmp_path) == expected_len, (
        f"Expected storage path length {expected_len}, "
        f"but got {_check_path_len(tmp_path)} instead. Info: {info}")


def process_prompt(processor, llm: LLM, question: str,
                   image_urls: list[Image.Image]):
    """
    Form the prompt from the text and image inputs, then generate output with
    the LLM.
    """
    placeholders = [{
        "type": "image_url",
        "image_url": {
            "url": f"data:image;base64,{encode_image_base64(image_pil)}"
        }
    } for image_pil in image_urls]

    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": [
                *placeholders,
                {
                    "type": "text",
                    "text": question
                },
            ],
        },
    ]

    prompt = processor.apply_chat_template(messages,
                                           tokenize=False,
                                           add_generation_prompt=True)
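
    # Note (assumption about the template): with tokenize=False this returns
    # the raw prompt string, and the Qwen2.5-VL chat template is expected to
    # replace each image entry with its vision placeholder tokens; the actual
    # pixel data is passed separately via multi_modal_data below.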

    outputs = llm.generate(
        {
            "prompt": prompt,
            # Attach multi-modal data only when images are provided, so the
            # pure-text cases exercise the text-only path.
            **({
                "multi_modal_data": {
                    "image": [*image_urls]
                }
            } if image_urls else {})
        },
        sampling_params=SAMPLING_PARAMS,
    )

    print("-" * 50)
    print("Output:")
    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)
    print("-" * 50)


def test_shared_storage_connector_hashes(tmp_path):
    """
    Test that SharedStorageConnector saves KV to the storage location under
    proper hashes that are unique for inputs with identical text but
    different images (of the same size), and for the same images given in a
    different order.
    """
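    # `tmp_path` is pytest's built-in temporary-directory fixture, so every
    # run of this test starts from an empty storage path.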
    # Use tmp_path as the KV storage path
    print(f"KV storage path at: {str(tmp_path)}")

    # Configure the SharedStorageConnector
    kv_transfer_config = KVTransferConfig(
        kv_connector="SharedStorageConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"shared_storage_path": str(tmp_path)})
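    # "kv_both" lets this single engine act as both KV producer and consumer,
    # so repeated prompts can reuse the entries saved by earlier cases.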

    engine_args = EngineArgs(
        model=MODEL_NAME,
        max_model_len=8192,
        max_num_seqs=1,
        kv_transfer_config=kv_transfer_config,
        limit_mm_per_prompt={"image": 2},
    )

    # Don't put this import at the top level:
    # it will call torch.cuda.device_count()
    from transformers import AutoProcessor

    # Create the processor that builds the chat prompt
    processor = AutoProcessor.from_pretrained(MODEL_NAME)

    # Prepare images for the tests.
    # Resize both to the same dimensions so that distinct hashes can only
    # come from distinct content, not from differing image sizes.
    image_1 = ImageAsset("stop_sign").pil_image.resize((1280, 720))
    image_2 = ImageAsset("cherry_blossom").pil_image.resize((1280, 720))

    # Make sure they are not the same picture
    assert image_1 != image_2, "The images should not be identical"

    # Create the LLM instance
    llm = LLM(**asdict(engine_args))

    # Prepare the input cases
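    # `expected_len` is the cumulative number of hash folders expected under
    # the storage path after the case runs; a repeated input must leave the
    # count unchanged.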
    input_cases = [
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_1],
                  expected_len=1,
                  info="image_1 single input the first time."),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_2],
                  expected_len=2,
                  info=("image_2 single input the first time. "
                        "It has the same pixel size as image_1, yet it "
                        "should still form a new unique hash.")),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_1],
                  expected_len=2,
                  info=("image_1 single input the 2nd time. "
                        "It should not form another new hash.")),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_2],
                  expected_len=2,
                  info=("image_2 single input the 2nd time. "
                        "It should not form another new hash.")),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_1, image_2],
                  expected_len=3,
                  info="image_1 with image_2 input the first time."),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_2, image_1],
                  expected_len=4,
                  info="The image order is swapped. Should form a new hash."),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_1, image_2],
                  expected_len=4,
                  info=("[image_1, image_2] input the 2nd time. "
                        "It should not form another new hash.")),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[image_2, image_1],
                  expected_len=4,
                  info=("[image_2, image_1] input the 2nd time. "
                        "It should not form another new hash.")),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[],
                  expected_len=5,
                  info="Pure text input test as a case-control."),
        InputCase(text=TEXT_PROMPTS[0],
                  img=[],
                  expected_len=5,
                  info="Identical pure text input as a case-control."),
        InputCase(text=TEXT_PROMPTS[1],
                  img=[],
                  expected_len=6,
                  info="Another pure text input as a case-control."),
    ]

    # Run the cases in order; expected_len accumulates across them
    for case_id, (text, img, expected_len, info) in enumerate(input_cases):
        print("\n", "=" * 25, f"Running input case: {case_id}", "=" * 25)
        run_test(tmp_path, processor, llm, text, img, expected_len, info)

    print("All tests passed successfully!")