Skip to content

Commit 30c33bb

Browse files
tjohnson31415njhill
authored and committed
test: add test for the time limit stopping criteria
Signed-off-by: Travis Johnson <[email protected]>
1 parent cf7a121 commit 30c33bb

File tree

1 file changed

+64
-1
lines changed

1 file changed

+64
-1
lines changed

integration_tests/text_generation_tests/test_server.py

Lines changed: 64 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@
33
import os
44
import random
55
import sys
6-
76
import yaml
87
import subprocess
98
import threading
@@ -319,6 +318,48 @@ async def _test_multi_input_seeds(stub):
319318
assert 0 <= seed <= 4294967295
320319

321320

321+
async def run_time_limit_test(stub, *, streaming=False, time_limit=200, min_generated_tokens=2):
    """Exercise the time-based stopping criterion against a running server.

    Sends a single greedy-generation request whose StoppingCriteria forces
    169 tokens (min == max) but caps wall time at ``time_limit`` ms, then
    checks that the server stopped for TIME_LIMIT, produced at least
    ``min_generated_tokens`` (but fewer than the full 100+), and that the
    observed elapsed time sits within the expected window.

    :param stub: GenerationService gRPC stub (async)
    :param streaming: use GenerateStream instead of the batched Generate
    :param time_limit: value for ``time_limit_millis`` in the request
    :param min_generated_tokens: lower bound on tokens the server must emit
    """
    req = pb2.GenerationRequest(text='def doit():\n')
    params = pb2.Parameters(
        method=pb2.GREEDY,
        stopping=pb2.StoppingCriteria(
            # min == max forces generation to keep going until the
            # time limit fires
            max_new_tokens=169,
            min_new_tokens=169,
            time_limit_millis=time_limit,
        )
    )

    started_ns = time.time_ns()
    if streaming:
        # fold the stream into a single response carrying the final
        # token count and stop reason
        final = pb2.GenerationResponse()
        stream = stub.GenerateStream(
            pb2.SingleGenerationRequest(request=req, params=params)
        )
        async for chunk in stream:
            final.generated_token_count = chunk.generated_token_count
            final.stop_reason = chunk.stop_reason
        response = final
    else:
        batch = await stub.Generate(
            pb2.BatchedGenerationRequest(requests=[req], params=params)
        )
        # exactly one request in the batch, so exactly one response
        response = batch.responses[0]
    elapsed_ms = (time.time_ns() - started_ns) / (10**6)

    assert response.stop_reason == pb2.StopReason.TIME_LIMIT
    # ensure that some tokens were actually generated
    assert min_generated_tokens <= response.generated_token_count < 100
    # generating all tokens takes a few seconds
    assert time_limit < elapsed_ms < time_limit+300
361+
362+
322363
@pytest.mark.model("gpt2")
323364
@pytest.mark.extensions(".safetensors,.json")
324365
@pytest.mark.shards(1)
@@ -382,6 +423,28 @@ async def test_mt0_output_special_tokens(server_fixture, test_cases):
382423
await run_test_cases_async(test_cases)
383424

384425

426+
# Test that the time based stopping criteria works
@pytest.mark.model("bigcode/tiny_starcoder_py")
@pytest.mark.extensions(".safetensors,.json")
@pytest.mark.shards(1)
@pytest.mark.asyncio
async def test_time_limit_stopping(server_fixture):
    """Run the time-limit scenarios in both batched and streaming modes."""
    async with grpc.aio.insecure_channel('localhost:8033') as channel:
        stub = gpb2.GenerationServiceStub(channel)
        # verify server is up with metrics request
        health = requests.get(f'http://localhost:{3000}/metrics')
        assert health.status_code == 200

        # same pair of checks for the batched and the streaming paths
        for use_streaming in (False, True):
            await run_time_limit_test(stub, streaming=use_streaming)
            # one token should always be generated, even with a 1 ms limit
            await run_time_limit_test(
                stub, streaming=use_streaming,
                time_limit=1, min_generated_tokens=1,
            )
447+
385448
# Test loading when an explicit local path is provided
386449
def test_explicit_path():
387450
# Test with and without providing TRANSFORMERS_CACHE env var

0 commit comments

Comments
 (0)