66
77import logging
88import os
9- from pathlib import Path
9+
10+ from dataclasses import dataclass
11+ from typing import List
1012
1113
1214import shortfin as sf
@@ -47,6 +49,22 @@ def __init__(self, count: int = 1):
4749 self .count = count
4850
4951
@dataclass
class FiberPool:
    """Bookkeeping for a set of fibers and the subset that is currently idle.

    The pool never creates fibers itself; the caller supplies both lists.
    `get_fiber` hands out idle fibers in FIFO order (oldest returned first)
    and `return_fiber` puts one back at the end of the idle queue.
    """

    # Every fiber known to the pool. This class never mutates the list.
    fibers: List[sf.Fiber]
    # Fibers available to be handed out; the front of the list is served first.
    idle_fibers: List[sf.Fiber]

    def get_fiber(self):
        """Remove and return the first idle fiber, or None if none is idle."""
        if not self.idle_fibers:
            return None

        return self.idle_fibers.pop(0)

    def return_fiber(self, fiber: sf.Fiber):
        """Mark `fiber` as idle again by appending it to the idle list.

        No check is made that `fiber` belongs to `fibers` or that it is not
        already idle; callers are expected to return each fiber exactly once.
        """
        self.idle_fibers.append(fiber)


67+
5068class LlmBatcherProcess (BatcherProcess ):
5169 """This batcher provides a high-level mechanism for dispatching LLM tasks."""
5270
@@ -56,13 +74,14 @@ class LlmBatcherProcess(BatcherProcess):
5674 def __init__ (
5775 self ,
5876 name : str ,
59- fiber : Fiber ,
77+ fiber_pool : FiberPool ,
6078 page_cache : BasePagedAttentionCache ,
6179 model_params : ModelParams ,
6280 functions : dict [int , sf .ProgramFunction ],
6381 ideal_batch_size : int ,
82+ program_isolation : str ,
6483 ):
65- super ().__init__ (fiber = fiber )
84+ super ().__init__ (fiber = fiber_pool . fibers [ 0 ] )
6685 self .name = name
6786 self .page_cache = page_cache
6887 self .model_params = model_params
@@ -74,6 +93,9 @@ def __init__(
7493 self .page_seq_stride = self .model_params .paged_kv_cache .block_seq_stride
7594 self ._current_workitems = 0
7695
96+ self .fiber_pool = fiber_pool
97+ self .program_isolation = program_isolation
98+
7799 def handle_inference_request (self , request ):
78100 """Handle an inference request."""
79101 self .pending .add (request )
@@ -115,25 +137,32 @@ async def board_flights(self):
115137 logger .info ("Waiting a bit longer to fill flight" )
116138 return
117139
140+ fiber = self .fiber_pool .get_fiber ()
141+ if fiber is None :
142+ logger .info ("Waiting for an idle fiber..." )
143+ return
144+
118145 self .strobes = 0
119146 cache = self .page_cache
120147
121- self .board (cache )
148+ self .board (cache , fiber )
122149 logger .debug ("Post boarding cache state: %r" , cache )
150+ if self .program_isolation != sf .ProgramIsolation .PER_FIBER :
151+ self .fiber_pool .return_fiber (fiber )
123152
124- def make_process (self , cache : BasePagedAttentionCache ):
153+ def make_process (self , cache : BasePagedAttentionCache , fiber : Fiber ):
125154 ...
126155
127156 def board_request (self , cache , request : LlmInferenceExecRequest ):
128157 ...
129158
130- def board (self , cache : BasePagedAttentionCache ):
159+ def board (self , cache : BasePagedAttentionCache , fiber : Fiber ):
131160 # Fill prefill flights.
132161 pending = self .pending
133162 if len (pending ) == 0 :
134163 return
135164
136- exec_process = self .make_process (cache )
165+ exec_process = self .make_process (cache , fiber )
137166
138167 for request in pending :
139168 if len (exec_process .exec_requests ) >= self .ideal_batch_size :
@@ -164,26 +193,30 @@ class PrefillBatcherProcess(LlmBatcherProcess):
164193
165194 def __init__ (
166195 self ,
167- fiber : Fiber ,
196+ fiber_pool : FiberPool ,
168197 page_cache : BasePagedAttentionCache ,
169198 model_params : ModelParams ,
170199 prefill_functions : dict [int , sf .ProgramFunction ],
200+ program_isolation : str ,
171201 ):
172202 super ().__init__ (
173203 name = "prefill" ,
174- fiber = fiber ,
204+ fiber_pool = fiber_pool ,
175205 page_cache = page_cache ,
176206 model_params = model_params ,
177207 functions = prefill_functions ,
178208 ideal_batch_size = max (model_params .prefill_batch_sizes ),
209+ program_isolation = program_isolation ,
179210 )
180211
181- def make_process (self , cache : BasePagedAttentionCache ):
212+ def make_process (self , cache : BasePagedAttentionCache , fiber : Fiber ):
182213 return PrefillExecutorProcess (
183- self . fiber ,
214+ fiber ,
184215 self .functions ,
185216 self .page_seq_stride ,
186217 cache .page_pool .page_tables ,
218+ self .fiber_pool ,
219+ self .program_isolation ,
187220 )
188221
189222 def board_request (self , cache , request : LlmInferenceExecRequest ):
@@ -216,26 +249,30 @@ class DecodeBatcherProcess(LlmBatcherProcess):
216249
217250 def __init__ (
218251 self ,
219- fiber : Fiber ,
252+ fiber_pool : FiberPool ,
220253 page_cache : BasePagedAttentionCache ,
221254 model_params : ModelParams ,
222255 decode_functions : dict [int , sf .ProgramFunction ],
256+ program_isolation : str ,
223257 ):
224258 super ().__init__ (
225259 name = "decode" ,
226- fiber = fiber ,
260+ fiber_pool = fiber_pool ,
227261 page_cache = page_cache ,
228262 model_params = model_params ,
229263 functions = decode_functions ,
230264 ideal_batch_size = max (model_params .decode_batch_sizes ),
265+ program_isolation = program_isolation ,
231266 )
232267
233- def make_process (self , cache : BasePagedAttentionCache ):
268+ def make_process (self , cache : BasePagedAttentionCache , fiber : Fiber ):
234269 return DecodeExecutorProcess (
235- self . fiber ,
270+ fiber ,
236271 self .functions ,
237272 self .page_seq_stride ,
238273 cache .page_pool .page_tables ,
274+ self .fiber_pool ,
275+ self .program_isolation ,
239276 )
240277
241278 def board_request (self , cache , request : LlmInferenceExecRequest ):
@@ -260,13 +297,17 @@ def __init__(
260297 functions : dict [int , sf .ProgramFunction ],
261298 seq_stride : int ,
262299 page_tables ,
300+ fiber_pool : FiberPool ,
301+ program_isolation : sf .ProgramIsolation ,
263302 ):
264303 super ().__init__ (fiber = fiber )
265304 self .name = name
266305 self .seq_stride = seq_stride
267306 self .exec_requests : list [LlmInferenceExecRequest ] = []
268307 self .page_tables = page_tables
269308 self .functions = functions
309+ self .fiber_pool = fiber_pool
310+ self .program_isolation = program_isolation
270311
271312 async def get_args (self , bs , device0 ):
272313 ...
@@ -345,13 +386,17 @@ def __init__(
345386 functions : dict [int , sf .ProgramFunction ],
346387 seq_stride : int ,
347388 page_tables ,
389+ fiber_pool : FiberPool ,
390+ program_isolation : sf .ProgramIsolation ,
348391 ):
349392 super ().__init__ (
350393 name = "prefill_process" ,
351394 fiber = fiber ,
352395 functions = functions ,
353396 seq_stride = seq_stride ,
354397 page_tables = page_tables ,
398+ fiber_pool = fiber_pool ,
399+ program_isolation = program_isolation ,
355400 )
356401
357402 async def get_args (self , bs , device0 ):
@@ -432,6 +477,9 @@ async def get_results(self, logits, req_count, device0):
432477 req .result_logits = logits_item
433478 req .done .set_success ()
434479
480+ if self .program_isolation == sf .ProgramIsolation .PER_FIBER :
481+ self .fiber_pool .return_fiber (self .fiber )
482+
435483
436484class DecodeExecutorProcess (LlmExecutorProcess ):
437485 """Executes a decode batch."""
@@ -442,13 +490,17 @@ def __init__(
442490 functions : dict [int , sf .ProgramFunction ],
443491 seq_stride : int ,
444492 page_tables ,
493+ fiber_pool : FiberPool ,
494+ isolation : sf .ProgramIsolation ,
445495 ):
446496 super ().__init__ (
447497 name = "decode_process" ,
448498 fiber = fiber ,
449499 functions = functions ,
450500 seq_stride = seq_stride ,
451501 page_tables = page_tables ,
502+ fiber_pool = fiber_pool ,
503+ program_isolation = isolation ,
452504 )
453505
454506 async def get_args (self , bs , device0 ):
@@ -545,3 +597,6 @@ async def get_results(self, logits, req_count, device0):
545597 else :
546598 req .result_logits = logits_item
547599 req .done .set_success ()
600+
601+ if self .program_isolation == sf .ProgramIsolation .PER_FIBER :
602+ self .fiber_pool .return_fiber (self .fiber )
0 commit comments