Skip to content

Commit f1c3d16

Browse files
[serving] Scaffolding for llm serving. (#409)
(needs bump for IREE runtime updates)
1 parent e955627 commit f1c3d16

File tree

13 files changed

+1929
-56
lines changed

13 files changed

+1929
-56
lines changed

core/iree-requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
iree-compiler==20240207.794
2-
iree-runtime==20240207.794
1+
iree-compiler==20240215.802
2+
iree-runtime==20240215.802
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright 2024 Advanced Micro Devices, Inc
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
import pytest
8+
9+
from turbine_serving.framework.session import (
10+
DeviceSession,
11+
)
12+
13+
14+
@pytest.fixture
15+
def local_device_session():
16+
session = DeviceSession(uri="local-task")
17+
yield session
18+
session.shutdown()
19+
20+
21+
def test_start_shutdown_no_host_contexts(local_device_session: DeviceSession):
22+
ms = local_device_session.create_module_set("default")
23+
ms.initialize()
24+
25+
26+
def test_host_context_start_stop(local_device_session: DeviceSession):
27+
ms = local_device_session.create_module_set("default")
28+
ms.initialize()
29+
hc = ms.host_context
30+
31+
32+
def test_host_context_scheduling(local_device_session: DeviceSession):
33+
device = local_device_session.device
34+
ms = local_device_session.create_module_set("default")
35+
ms.initialize()
36+
hc = ms.host_context
37+
38+
sem = device.create_semaphore(0)
39+
40+
async def task1():
41+
print("[coro1] test_host_context_scheduling.task")
42+
await hc.on_semaphore(sem, 1, True)
43+
print("[coro1] await completed")
44+
sem.signal(2)
45+
46+
async def task2():
47+
print("[coro2] waiting for 2")
48+
await hc.on_semaphore(sem, 2, True)
49+
sem.fail("Fail from task2")
50+
51+
f1 = hc.run_concurrent(task1())
52+
f2 = hc.run_concurrent(task2())
53+
sem.signal(1)
54+
print("[main] Waiting for semaphore")
55+
56+
# Ensure task completion. Important to consume to ensure that exceptions
57+
# propagate.
58+
f1.result()
59+
f2.result()
60+
61+
print("[main] Waiting on semaphore payload 3")
62+
with pytest.raises(Exception, match="Fail from task2"):
63+
sem.wait(3)

serving/tests/api_server_test.py renamed to serving/tests/llm/api_server_test.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ def __init__(self, args):
2121
[
2222
sys.executable,
2323
"-m",
24-
"turbine_serving.llm.entrypoints.api_server",
24+
"turbine_serving.llm.api.rest_server",
25+
"--testing-mock-service",
2526
]
2627
+ args,
2728
env=env,
@@ -39,9 +40,9 @@ def _wait_for_ready(self):
3940
except Exception as e:
4041
if self.process.poll() is not None:
4142
raise RuntimeError("API server processs terminated") from e
42-
time.sleep(0.25)
43+
time.sleep(1.0)
4344
if time.time() - start > 30:
44-
raise RuntimeError("Timeout waiting for server start") from e
45+
raise RuntimeError("Timeout waiting for server start")
4546

4647
def __del__(self):
4748
try:
@@ -59,5 +60,30 @@ def server():
5960
yield runner
6061

6162

62-
def test_basic(server: ServerRunner):
63+
def test_health(server: ServerRunner):
64+
# Health check is part of getting the fixture.
6365
...
66+
67+
68+
def test_generate_non_streaming(server: ServerRunner):
69+
resp = requests.post(
70+
f"{server.url}/generate",
71+
json={
72+
"prompt": "Hi Bob",
73+
},
74+
)
75+
resp.raise_for_status()
76+
d = resp.json()
77+
assert d["text"] == "Hi Bob", repr(d)
78+
79+
80+
def test_generate_streaming(server: ServerRunner):
81+
resp = requests.post(
82+
f"{server.url}/generate", json={"prompt": "Hi Bob!", "stream": True}
83+
)
84+
resp.raise_for_status()
85+
full_contents = resp.content
86+
expected_contents = b'{"text": "Hi Bob!"}\x00' * 5
87+
assert (
88+
full_contents == expected_contents
89+
), f"Expected {expected_contents!r} vs {full_contents!r}"

serving/tests/llm/service_v1_test.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# Copyright 2024 Advanced Micro Devices, Inc
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
import pytest
8+
9+
from iree.runtime import ( # type: ignore
10+
HalElementType,
11+
)
12+
13+
from turbine_serving.framework.session import DeviceSession
14+
from turbine_serving.llm.config import (
15+
CacheParams,
16+
ModelParams,
17+
ServiceParams,
18+
)
19+
20+
from turbine_serving.llm.service import (
21+
GenerateRequest,
22+
GenerateResponsePart,
23+
)
24+
25+
from turbine_serving.llm.attn_block_cache import (
26+
create_attn_block_cache_module,
27+
AttnBlockCache,
28+
)
29+
30+
from turbine_serving.llm.impl.service_v1 import (
31+
GenerateServiceV1,
32+
)
33+
34+
from turbine_serving.llm.testing.fake_v1_module import (
35+
create_fake_module,
36+
)
37+
38+
39+
@pytest.fixture
40+
def cache_params(model_params: ModelParams) -> CacheParams:
41+
return CacheParams(model=model_params, device_block_count=128, block_pos_stride=16)
42+
43+
44+
@pytest.fixture
45+
def model_params() -> ModelParams:
46+
return ModelParams(
47+
module_name="AwesomeLLM",
48+
module_abi_version=1,
49+
attn_dtype=HalElementType.FLOAT_16,
50+
max_seq_len=128,
51+
transformer_block_count=32,
52+
attn_head_count=32,
53+
attn_head_dim=128,
54+
prefill_batch_sizes=[1, 4, 16],
55+
decode_batch_sizes=[1, 4, 16],
56+
)
57+
58+
59+
@pytest.fixture
60+
def uninitialized_session(model_params: ModelParams):
61+
from iree.runtime._binding import disable_leak_checker # type: ignore
62+
63+
disable_leak_checker()
64+
session = DeviceSession(uri="local-task", queue_count=2)
65+
yield session
66+
session.shutdown()
67+
del session
68+
69+
70+
@pytest.fixture
71+
def attn_block_cache(
72+
uninitialized_session: DeviceSession, cache_params: CacheParams
73+
) -> AttnBlockCache:
74+
return AttnBlockCache(uninitialized_session, cache_params)
75+
76+
77+
@pytest.fixture
78+
def session(
79+
model_params: ModelParams,
80+
uninitialized_session: DeviceSession,
81+
attn_block_cache: AttnBlockCache,
82+
):
83+
session = uninitialized_session
84+
lms = session.create_module_set("AwesomeLLM", context_count=1)
85+
lms.add(
86+
create_attn_block_cache_module(attn_block_cache),
87+
create_fake_module(session.device, "AwesomeLLM", model_params=model_params),
88+
)
89+
lms.initialize()
90+
return session
91+
92+
93+
@pytest.fixture
94+
def service(
95+
session: DeviceSession,
96+
cache_params: CacheParams,
97+
model_params: ModelParams,
98+
attn_block_cache: AttnBlockCache,
99+
):
100+
params = ServiceParams(cache=cache_params, model=model_params)
101+
return GenerateServiceV1(session=session, params=params, cache=attn_block_cache)
102+
103+
104+
def test_single(service: GenerateServiceV1):
105+
state = service.start()
106+
107+
async def task():
108+
await state.set_sequences(
109+
requests=[
110+
GenerateRequest(
111+
"1",
112+
"hello, tell me a story",
113+
[3, 4, 5, 12, 23, 88, 10, 2, 5, 9, 12, 13, 99, 56, 33, 124, 73],
114+
),
115+
GenerateRequest("2", "goodbye", [9, 10]),
116+
]
117+
)
118+
guarded_outputs = await state.prefill()
119+
prefill_ids = await guarded_outputs.resolve(state.host_context)
120+
print(
121+
"PREFILL IDS:",
122+
prefill_ids,
123+
":\n",
124+
prefill_ids.map().asarray(
125+
prefill_ids.shape, HalElementType.map_to_dtype(prefill_ids.element_type)
126+
),
127+
)
128+
await state.recycle()
129+
130+
state.host_context.run_sync(task())
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2024 Advanced Micro Devices, Inc
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
import logging
8+
import os
9+
import sys
10+
11+
12+
# Whether debug assertions are disabled.
13+
NDEBUG: bool = False
14+
15+
_default_log_level = os.getenv("TURBINE_LOG_LEVEL", "DEBUG")
16+
17+
18+
class DefaultFormatter(logging.Formatter):
19+
def __init__(self):
20+
super().__init__(
21+
"%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s",
22+
"%m-%d %H:%M:%S",
23+
)
24+
25+
26+
def _setup_logger():
27+
root_logger = logging.getLogger("turbine_serving")
28+
root_logger.setLevel(logging.DEBUG)
29+
default_handler = logging.StreamHandler(sys.stderr)
30+
default_handler.flush = sys.stderr.flush
31+
default_handler.setLevel(_default_log_level)
32+
default_handler.setFormatter(DefaultFormatter())
33+
root_logger.addHandler(default_handler)
34+
root_logger.propagate = False
35+
return root_logger, default_handler
36+
37+
38+
root_logger, default_handler = _setup_logger()
39+
40+
logging.getLogger("asyncio").addHandler(default_handler)
41+
42+
43+
def get_logger(name: str):
44+
logger = logging.getLogger(name)
45+
logger.setLevel(_default_log_level)
46+
logger.addHandler(default_handler)
47+
logger.propagate = False
48+
return logger

0 commit comments

Comments
 (0)