 import os
 import random
 import sys
-
 import yaml
 import subprocess
 import threading
@@ -319,6 +318,48 @@ async def _test_multi_input_seeds(stub):
         assert 0 <= seed <= 4294967295


+async def run_time_limit_test(stub, *, streaming=False, time_limit=200, min_generated_tokens=2):
+    generation_request = pb2.GenerationRequest(
+        text='def doit():\n'
+    )
+    generation_params = pb2.Parameters(
+        method=pb2.GREEDY,
+        stopping=pb2.StoppingCriteria(
+            max_new_tokens=169,
+            min_new_tokens=169,
+            time_limit_millis=time_limit,
+        )
+    )
+
+    start = time.time_ns()
+    if streaming:
+        response = pb2.GenerationResponse()
+        async for resp in stub.GenerateStream(
+            pb2.SingleGenerationRequest(
+                request=generation_request,
+                params=generation_params
+            )
+        ):
+            response.generated_token_count = resp.generated_token_count
+            response.stop_reason = resp.stop_reason
+    else:
+        response = await stub.Generate(
+            pb2.BatchedGenerationRequest(
+                requests=[generation_request],
+                params=generation_params,
+            )
+        )
+        # single req/resp in the batch
+        response = response.responses[0]
+    end = time.time_ns()
+
+    assert response.stop_reason == pb2.StopReason.TIME_LIMIT
+    # ensure that some tokens were actually generated
+    assert min_generated_tokens <= response.generated_token_count < 100
+    # generating all tokens takes a few seconds
+    assert time_limit < (end-start) / (10**6) < time_limit+300
+
+
 @pytest.mark.model("gpt2")
 @pytest.mark.extensions(".safetensors,.json")
 @pytest.mark.shards(1)
@@ -382,6 +423,28 @@ async def test_mt0_output_special_tokens(server_fixture, test_cases):
     await run_test_cases_async(test_cases)


+# Test that the time based stopping criteria works
+@pytest.mark.model("bigcode/tiny_starcoder_py")
+@pytest.mark.extensions(".safetensors,.json")
+@pytest.mark.shards(1)
+@pytest.mark.asyncio
+async def test_time_limit_stopping(server_fixture):
+    async with grpc.aio.insecure_channel('localhost:8033') as channel:
+        stub = gpb2.GenerationServiceStub(channel)
+        # verify server is up with metrics request
+        response = requests.get(f'http://localhost:{3000}/metrics')
+        assert response.status_code == 200
+
+        # batched
+        await run_time_limit_test(stub)
+        # one token should always be generated
+        await run_time_limit_test(stub, time_limit=1, min_generated_tokens=1)
+
+        # streaming
+        await run_time_limit_test(stub, streaming=True)
+        # one token should always be generated
+        await run_time_limit_test(stub, streaming=True, time_limit=1, min_generated_tokens=1)
+
 # Test loading when an explicit local path is provided
 def test_explicit_path():
     # Test with and without providing TRANSFORMERS_CACHE env var