Skip to content

Commit cb25984

Browse files
authored
Shortfin llm multi workers and multi fibers (#1280)
Enable spinning up the server with multiple workers and multiple fibers per worker to facilitate higher parallelism and concurrency in prefill/decode invocations. ## TODO: - Fix `per_fiber` isolation - I'm also attempting to enable `per_fiber` program isolation. Currently, we've just been using `per_call`. As I understand it, `per_fiber` may end up being faster, but I'm getting a `std::bad_cast` error when attempting to create the input device_arrays in `get_args`. Still looking into this... Otherwise, this is good to use with `per_call` isolation. - Right now we're in a performance sprint and get more benefit by deferring this fix. Created [an issue](#1284) for it, and will raise a `NotImplementedError` in code for this case.
1 parent 104ab43 commit cb25984

File tree

7 files changed

+141
-31
lines changed

7 files changed

+141
-31
lines changed

app_tests/integration_tests/llm/shortfin/direct_to_batcher_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class BatchConsistencyTestProcess(sf.Process):
3838
"""
3939

4040
def __init__(self, service, input_tokens, batch_sizes, max_response_length):
41-
super().__init__(fiber=service.main_fiber)
41+
super().__init__(fiber=service.fiber_pool.fibers[0])
4242
self.service = service
4343
self.input_tokens = input_tokens
4444
self.batch_sizes = batch_sizes

shortfin/python/shortfin_apps/llm/cli.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def add_input_args(parser):
3636

3737

3838
def add_service_args(parser: argparse.ArgumentParser):
39+
# TODO separate the server args from the `offline` args
3940
get_system_args(parser)
4041

4142
parser.add_argument(
@@ -70,7 +71,7 @@ def add_service_args(parser: argparse.ArgumentParser):
7071
metavar="FILE",
7172
)
7273
parser.add_argument(
73-
"--isolation",
74+
"--program_isolation",
7475
type=str,
7576
default="per_call",
7677
choices=[isolation.name.lower() for isolation in ProgramIsolation],
@@ -112,11 +113,23 @@ def add_service_args(parser: argparse.ArgumentParser):
112113
required=False,
113114
help="Temperature value to use for `offline` generation.",
114115
)
116+
parser.add_argument(
117+
"--workers_offline",
118+
type=int,
119+
default=1,
120+
help="Number of workers to use when running in `offline` mode.",
121+
)
115122
parser.add_argument(
116123
"--workers",
117124
type=int,
118125
default=1,
119-
help="Number of concurrent requests that should be running",
126+
help="Number of workers to use when running in `server` mode.",
127+
)
128+
parser.add_argument(
129+
"--fibers_per_worker",
130+
type=int,
131+
default=1,
132+
help="Number of fibers to use per worker.",
120133
)
121134
parser.add_argument(
122135
"--benchmark",
@@ -243,10 +256,10 @@ async def worker(name, queue, fiber):
243256
task.result = responder.response.result()
244257
queue.task_done()
245258

246-
logger.info(msg=f"Setting up {args.workers} workers")
259+
logger.info(msg=f"Setting up {args.workers_offline} workers")
247260
workers = []
248261
queue = asyncio.Queue()
249-
for i in range(args.workers):
262+
for i in range(args.workers_offline):
250263
name = f"worker-{i}"
251264
workerr = service.sysman.ls.create_worker(name)
252265
fiber = service.sysman.ls.create_fiber(workerr)

shortfin/python/shortfin_apps/llm/components/batcher.py

Lines changed: 70 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
import logging
88
import os
9-
from pathlib import Path
9+
10+
from dataclasses import dataclass
11+
from typing import List
1012

1113

1214
import shortfin as sf
@@ -47,6 +49,22 @@ def __init__(self, count: int = 1):
4749
self.count = count
4850

4951

52+
@dataclass
53+
class FiberPool:
54+
55+
fibers: List[sf.Fiber]
56+
idle_fibers: List[sf.Fiber]
57+
58+
def get_fiber(self):
59+
if len(self.idle_fibers) == 0:
60+
return None
61+
62+
return self.idle_fibers.pop(0)
63+
64+
def return_fiber(self, fiber: sf.Fiber):
65+
self.idle_fibers.append(fiber)
66+
67+
5068
class LlmBatcherProcess(BatcherProcess):
5169
"""This batcher provides a high-level mechanism for dispatching LLM tasks."""
5270

@@ -56,13 +74,14 @@ class LlmBatcherProcess(BatcherProcess):
5674
def __init__(
5775
self,
5876
name: str,
59-
fiber: Fiber,
77+
fiber_pool: FiberPool,
6078
page_cache: BasePagedAttentionCache,
6179
model_params: ModelParams,
6280
functions: dict[int, sf.ProgramFunction],
6381
ideal_batch_size: int,
82+
program_isolation: str,
6483
):
65-
super().__init__(fiber=fiber)
84+
super().__init__(fiber=fiber_pool.fibers[0])
6685
self.name = name
6786
self.page_cache = page_cache
6887
self.model_params = model_params
@@ -74,6 +93,9 @@ def __init__(
7493
self.page_seq_stride = self.model_params.paged_kv_cache.block_seq_stride
7594
self._current_workitems = 0
7695

96+
self.fiber_pool = fiber_pool
97+
self.program_isolation = program_isolation
98+
7799
def handle_inference_request(self, request):
78100
"""Handle an inference request."""
79101
self.pending.add(request)
@@ -115,25 +137,32 @@ async def board_flights(self):
115137
logger.info("Waiting a bit longer to fill flight")
116138
return
117139

140+
fiber = self.fiber_pool.get_fiber()
141+
if fiber is None:
142+
logger.info("Waiting for an idle fiber...")
143+
return
144+
118145
self.strobes = 0
119146
cache = self.page_cache
120147

121-
self.board(cache)
148+
self.board(cache, fiber)
122149
logger.debug("Post boarding cache state: %r", cache)
150+
if self.program_isolation != sf.ProgramIsolation.PER_FIBER:
151+
self.fiber_pool.return_fiber(fiber)
123152

124-
def make_process(self, cache: BasePagedAttentionCache):
153+
def make_process(self, cache: BasePagedAttentionCache, fiber: Fiber):
125154
...
126155

127156
def board_request(self, cache, request: LlmInferenceExecRequest):
128157
...
129158

130-
def board(self, cache: BasePagedAttentionCache):
159+
def board(self, cache: BasePagedAttentionCache, fiber: Fiber):
131160
# Fill prefill flights.
132161
pending = self.pending
133162
if len(pending) == 0:
134163
return
135164

136-
exec_process = self.make_process(cache)
165+
exec_process = self.make_process(cache, fiber)
137166

138167
for request in pending:
139168
if len(exec_process.exec_requests) >= self.ideal_batch_size:
@@ -164,26 +193,30 @@ class PrefillBatcherProcess(LlmBatcherProcess):
164193

165194
def __init__(
166195
self,
167-
fiber: Fiber,
196+
fiber_pool: FiberPool,
168197
page_cache: BasePagedAttentionCache,
169198
model_params: ModelParams,
170199
prefill_functions: dict[int, sf.ProgramFunction],
200+
program_isolation: str,
171201
):
172202
super().__init__(
173203
name="prefill",
174-
fiber=fiber,
204+
fiber_pool=fiber_pool,
175205
page_cache=page_cache,
176206
model_params=model_params,
177207
functions=prefill_functions,
178208
ideal_batch_size=max(model_params.prefill_batch_sizes),
209+
program_isolation=program_isolation,
179210
)
180211

181-
def make_process(self, cache: BasePagedAttentionCache):
212+
def make_process(self, cache: BasePagedAttentionCache, fiber: Fiber):
182213
return PrefillExecutorProcess(
183-
self.fiber,
214+
fiber,
184215
self.functions,
185216
self.page_seq_stride,
186217
cache.page_pool.page_tables,
218+
self.fiber_pool,
219+
self.program_isolation,
187220
)
188221

189222
def board_request(self, cache, request: LlmInferenceExecRequest):
@@ -216,26 +249,30 @@ class DecodeBatcherProcess(LlmBatcherProcess):
216249

217250
def __init__(
218251
self,
219-
fiber: Fiber,
252+
fiber_pool: FiberPool,
220253
page_cache: BasePagedAttentionCache,
221254
model_params: ModelParams,
222255
decode_functions: dict[int, sf.ProgramFunction],
256+
program_isolation: str,
223257
):
224258
super().__init__(
225259
name="decode",
226-
fiber=fiber,
260+
fiber_pool=fiber_pool,
227261
page_cache=page_cache,
228262
model_params=model_params,
229263
functions=decode_functions,
230264
ideal_batch_size=max(model_params.decode_batch_sizes),
265+
program_isolation=program_isolation,
231266
)
232267

233-
def make_process(self, cache: BasePagedAttentionCache):
268+
def make_process(self, cache: BasePagedAttentionCache, fiber: Fiber):
234269
return DecodeExecutorProcess(
235-
self.fiber,
270+
fiber,
236271
self.functions,
237272
self.page_seq_stride,
238273
cache.page_pool.page_tables,
274+
self.fiber_pool,
275+
self.program_isolation,
239276
)
240277

241278
def board_request(self, cache, request: LlmInferenceExecRequest):
@@ -260,13 +297,17 @@ def __init__(
260297
functions: dict[int, sf.ProgramFunction],
261298
seq_stride: int,
262299
page_tables,
300+
fiber_pool: FiberPool,
301+
program_isolation: sf.ProgramIsolation,
263302
):
264303
super().__init__(fiber=fiber)
265304
self.name = name
266305
self.seq_stride = seq_stride
267306
self.exec_requests: list[LlmInferenceExecRequest] = []
268307
self.page_tables = page_tables
269308
self.functions = functions
309+
self.fiber_pool = fiber_pool
310+
self.program_isolation = program_isolation
270311

271312
async def get_args(self, bs, device0):
272313
...
@@ -345,13 +386,17 @@ def __init__(
345386
functions: dict[int, sf.ProgramFunction],
346387
seq_stride: int,
347388
page_tables,
389+
fiber_pool: FiberPool,
390+
program_isolation: sf.ProgramIsolation,
348391
):
349392
super().__init__(
350393
name="prefill_process",
351394
fiber=fiber,
352395
functions=functions,
353396
seq_stride=seq_stride,
354397
page_tables=page_tables,
398+
fiber_pool=fiber_pool,
399+
program_isolation=program_isolation,
355400
)
356401

357402
async def get_args(self, bs, device0):
@@ -432,6 +477,9 @@ async def get_results(self, logits, req_count, device0):
432477
req.result_logits = logits_item
433478
req.done.set_success()
434479

480+
if self.program_isolation == sf.ProgramIsolation.PER_FIBER:
481+
self.fiber_pool.return_fiber(self.fiber)
482+
435483

436484
class DecodeExecutorProcess(LlmExecutorProcess):
437485
"""Executes a decode batch."""
@@ -442,13 +490,17 @@ def __init__(
442490
functions: dict[int, sf.ProgramFunction],
443491
seq_stride: int,
444492
page_tables,
493+
fiber_pool: FiberPool,
494+
isolation: sf.ProgramIsolation,
445495
):
446496
super().__init__(
447497
name="decode_process",
448498
fiber=fiber,
449499
functions=functions,
450500
seq_stride=seq_stride,
451501
page_tables=page_tables,
502+
fiber_pool=fiber_pool,
503+
program_isolation=isolation,
452504
)
453505

454506
async def get_args(self, bs, device0):
@@ -545,3 +597,6 @@ async def get_results(self, logits, req_count, device0):
545597
else:
546598
req.result_logits = logits_item
547599
req.done.set_success()
600+
601+
if self.program_isolation == sf.ProgramIsolation.PER_FIBER:
602+
self.fiber_pool.return_fiber(self.fiber)

shortfin/python/shortfin_apps/llm/components/config_struct.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,12 @@ class ServerParams:
212212
# Program isolation configuration
213213
program_isolation: str = "per_call"
214214

215+
# Number of shortfin workers to use during generation
216+
workers: int = 1
217+
218+
# Number of fibers to create per worker
219+
fibers_per_worker: int = 1
220+
215221
decode_config: DecodeConfig | None = None
216222

217223
# Device configuration

shortfin/python/shortfin_apps/llm/components/generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def __init__(
142142
responder: FastAPIResponder,
143143
fiber: sf.Fiber | None = None,
144144
):
145-
super().__init__(fiber=service.main_fiber if fiber is None else fiber)
145+
super().__init__(fiber=service.fiber_pool.fibers[0] if fiber is None else fiber)
146146
self.service = service
147147
self.gen_req = gen_req
148148
self.responder = responder

0 commit comments

Comments (0)