@@ -1,13 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import io
 from dataclasses import dataclass
 from typing import Optional
 from unittest.mock import AsyncMock, MagicMock
 
+import pybase64
 import pytest
+import torch
 
 from vllm.entrypoints.renderer import CompletionRenderer
+from vllm.inputs.data import is_embeds_prompt
 
 
 @dataclass
@@ -178,3 +182,132 @@ async def test_no_tokenizer_for_text(self, mock_model_config): |
         with pytest.raises(ValueError, match="No tokenizer available"):
             await renderer_no_tokenizer.render_prompt(
                 prompt_or_prompts="Hello world", max_length=100)
+
+    @pytest.mark.asyncio
+    async def test_token_input_with_needs_detokenization(
+            self, renderer, mock_async_tokenizer):
+        # When needs_detokenization=True for token inputs, the renderer
+        # should use the async tokenizer to decode the tokens and include
+        # the decoded text in the returned prompt object.
+        mock_async_tokenizer.decode = AsyncMock(return_value="decoded text")
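+        # Register the mock in the async tokenizer pool (keyed by the sync
+        # tokenizer) so the renderer's decode call resolves to the mock.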
+        renderer.async_tokenizer_pool[
+            renderer.tokenizer] = mock_async_tokenizer
+
+        tokens = [1, 2, 3, 4]
+        results = await renderer.render_prompt(
+            prompt_or_prompts=tokens,
+            needs_detokenization=True,
+        )
+
+        assert len(results) == 1
+        assert results[0]["prompt_token_ids"] == tokens
+        assert results[0]["prompt"] == "decoded text"
+        mock_async_tokenizer.decode.assert_awaited_once()
+
+
+class TestRenderEmbedPrompt:
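+    """Tests for rendering base64-encoded prompt embeddings via
+    render_prompt_and_embeds."""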
+
+    def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes:
+        """Helper to create base64-encoded tensor bytes"""
+        buffer = io.BytesIO()
+        torch.save(tensor, buffer)
+        buffer.seek(0)
+        return pybase64.b64encode(buffer.read())
+
+    @pytest.mark.asyncio
+    async def test_single_prompt_embed(self, renderer):
+        # Create a test tensor
+        test_tensor = torch.randn(10, 768, dtype=torch.float32)
+        embed_bytes = self._create_test_embed_bytes(test_tensor)
+
+        results = await renderer.render_prompt_and_embeds(
+            prompt_embeds=embed_bytes, cache_salt="test_salt")
+
+        assert len(results) == 1
+        assert is_embeds_prompt(results[0])
+        assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
+        assert results[0]["cache_salt"] == "test_salt"
+
+    @pytest.mark.asyncio
+    async def test_multiple_prompt_embeds(self, renderer):
+        # Create multiple test tensors
+        test_tensors = [
+            torch.randn(8, 512, dtype=torch.float32),
+            torch.randn(12, 512, dtype=torch.float32),
+        ]
+        embed_bytes_list = [
+            self._create_test_embed_bytes(t) for t in test_tensors
+        ]
+
+        results = await renderer.render_prompt_and_embeds(
+            prompt_embeds=embed_bytes_list)
+
+        assert len(results) == 2
+        for i, result in enumerate(results):
+            assert is_embeds_prompt(result)
+            assert torch.allclose(result["prompt_embeds"], test_tensors[i])
+
+    @pytest.mark.asyncio
+    async def test_prompt_embed_truncation(self, renderer):
+        # Create tensor with more tokens than truncation limit
+        test_tensor = torch.randn(20, 768, dtype=torch.float32)
+        embed_bytes = self._create_test_embed_bytes(test_tensor)
+
+        results = await renderer.render_prompt_and_embeds(
+            prompt_embeds=embed_bytes, truncate_prompt_tokens=10)
+
+        assert len(results) == 1
+        # Should keep last 10 tokens
+        expected = test_tensor[-10:]
+        assert torch.allclose(results[0]["prompt_embeds"], expected)
+
+    @pytest.mark.asyncio
+    async def test_prompt_embed_different_dtypes(self, renderer):
+        # Test different supported dtypes
+        dtypes = [torch.float32, torch.float16, torch.bfloat16]
+
+        for dtype in dtypes:
+            test_tensor = torch.randn(5, 256, dtype=dtype)
+            embed_bytes = self._create_test_embed_bytes(test_tensor)
+
+            results = await renderer.render_prompt_and_embeds(
+                prompt_embeds=embed_bytes)
+
+            assert len(results) == 1
+            assert results[0]["prompt_embeds"].dtype == dtype
+
+    @pytest.mark.asyncio
+    async def test_prompt_embed_squeeze_batch_dim(self, renderer):
+        # Test tensor with batch dimension gets squeezed
+        test_tensor = torch.randn(1, 10, 768, dtype=torch.float32)
+        embed_bytes = self._create_test_embed_bytes(test_tensor)
+
+        results = await renderer.render_prompt_and_embeds(
+            prompt_embeds=embed_bytes)
+
+        assert len(results) == 1
+        # Should be squeezed to 2D
+        assert results[0]["prompt_embeds"].shape == (10, 768)
+
+    @pytest.mark.asyncio
+    async def test_both_prompts_and_embeds(self, renderer,
+                                           mock_async_tokenizer):
+        # Set up text tokenization
+        mock_async_tokenizer.return_value = MockTokenizerResult(
+            [101, 102, 103])
+        renderer.async_tokenizer_pool[
+            renderer.tokenizer] = mock_async_tokenizer
+
+        # Create embed
+        test_tensor = torch.randn(5, 256, dtype=torch.float32)
+        embed_bytes = self._create_test_embed_bytes(test_tensor)
+
+        results = await renderer.render_prompt_and_embeds(
+            prompt_or_prompts="Hello world", prompt_embeds=embed_bytes)
+
+        assert len(results) == 2
+        # First should be embed prompt
+        assert is_embeds_prompt(results[0])
+        # Second should be tokens prompt
+        assert "prompt_token_ids" in results[1]
+        assert results[1]["prompt_token_ids"] == [101, 102, 103]