@@ -10,9 +10,13 @@ const Epsilon = float32(0.001)
1010// fraction of maximum server throughput to provide stability (running this fraction below the maximum)
1111const StabilitySafetyFraction = float32 (0.1 )
1212
13+ // maximum number of tokens per batch (iteration)
14+ const DefaultMaxNumTokens = 8192
15+
1316// Analyzer of inference server queue
1417type QueueAnalyzer struct {
1518 MaxBatchSize int // maximum batch size
19+ MaxNumTokens int // maximum number of tokens per batch
1620 MaxQueueSize int // maximum queue size
1721 ServiceParms * ServiceParms // request processing parameters
1822 RequestSize * RequestSize // number of input and output tokens per request
@@ -23,32 +27,23 @@ type QueueAnalyzer struct {
2327// queue configuration parameters
2428type Configuration struct {
2529 MaxBatchSize int // maximum batch size (limit on the number of requests concurrently receiving service >0)
30+ MaxNumTokens int // maximum number of tokens per batch (limit on the number of tokens per batch >0)
2631 MaxQueueSize int // maximum queue size (limit on the number of requests queued for service >=0)
2732 ServiceParms * ServiceParms // request processing parameters
2833}
2934
30- // request processing parameters
35+ // request processing parameters:
36+ // iterationTime = alpha + beta * computeTime + gamma * memoryAccessTime
3137type ServiceParms struct {
32- Prefill * PrefillParms // parameters to calculate prefill time
33- Decode * DecodeParms // parameters to calculate decode time
34- }
35-
36- // prefill time = gamma + delta * inputTokens * batchSize (msec); inputTokens > 0
37- type PrefillParms struct {
38- Gamma float32 // base
39- Delta float32 // slope
40- }
41-
42- // decode time = alpha + beta * batchSize (msec); batchSize > 0
43- type DecodeParms struct {
4438 Alpha float32 // base
45- Beta float32 // slope
39+ Beta float32 // slope for compute time
40+ Gamma float32 // slope for memory access time
4641}
4742
4843// request tokens data
4944type RequestSize struct {
50- AvgInputTokens int // average number of input tokens per request
51- AvgOutputTokens int // average number of output tokens per request
45+ AvgInputTokens float32 // average number of input tokens per request
46+ AvgOutputTokens float32 // average number of output tokens per request
5247}
5348
5449// range of request rates (requests/sec)
@@ -65,6 +60,7 @@ type AnalysisMetrics struct {
6560 AvgNumInServ float32 // average number of requests in service
6661 AvgPrefillTime float32 // average request prefill time (msec)
6762 AvgTokenTime float32 // average token decode time (msec)
63+ AvgTTFT float32 // average time to first token (msec)
6864 MaxRate float32 // maximum throughput (requests/sec)
6965 Rho float32 // utilization
7066}
@@ -96,35 +92,32 @@ func NewQueueAnalyzer(qConfig *Configuration, requestSize *RequestSize) (*QueueA
9692}
9793
9894// build queueing model using service rates, leaving arrival rate as parameter
99- func BuildModel (qConfig * Configuration , requestSize * RequestSize ) (modelData * QueueAnalyzer ) {
100- parms := qConfig .ServiceParms
95+ func BuildModel (c * Configuration , r * RequestSize ) (modelData * QueueAnalyzer ) {
96+ parms := c .ServiceParms
10197
10298 // calculate state-dependent service rate
103- servRate := make ([]float32 , qConfig .MaxBatchSize )
104- for n := 1 ; n <= qConfig .MaxBatchSize ; n ++ {
105- prefillTime := parms .Prefill .PrefillTime (requestSize .AvgInputTokens , float32 (n ))
106- numDecode := requestSize .AvgOutputTokens - 1 // number of decodes (one per output token except the first)
107- // special case: allow one decode in case of decode only and one output token
108- if requestSize .AvgInputTokens == 0 && requestSize .AvgOutputTokens == 1 {
109- numDecode = 1
110- }
111- decodeTime := float32 (numDecode ) * parms .Decode .DecodeTime (float32 (n ))
99+ servRate := make ([]float32 , c .MaxBatchSize )
100+ for n := 1 ; n <= c .MaxBatchSize ; n ++ {
101+ prefillTime := parms .PrefillTime (r , float32 (n ))
102+ decodeTime := r .AvgOutputTokens * parms .DecodeTime (r , float32 (n ))
112103 servRate [n - 1 ] = float32 (n ) / (prefillTime + decodeTime )
113104 }
114105
115106 // set and check limits
116107 lambdaMin := servRate [0 ] * Epsilon
117- lambdaMax := servRate [qConfig .MaxBatchSize - 1 ] * (1 - Epsilon )
108+ lambdaMax := servRate [c .MaxBatchSize - 1 ] * (1 - Epsilon )
118109 rateRange := & RateRange {Min : lambdaMin * 1000 , Max : lambdaMax * 1000 }
119110
120111 // create and solve model
121- occupancyUpperBound := qConfig .MaxQueueSize + qConfig .MaxBatchSize
112+ occupancyUpperBound := c .MaxQueueSize + c .MaxBatchSize
122113 model := NewMM1ModelStateDependent (occupancyUpperBound , servRate )
114+
123115 return & QueueAnalyzer {
124- MaxBatchSize : qConfig .MaxBatchSize ,
125- MaxQueueSize : qConfig .MaxQueueSize ,
116+ MaxBatchSize : c .MaxBatchSize ,
117+ MaxNumTokens : c .MaxNumTokens ,
118+ MaxQueueSize : c .MaxQueueSize ,
126119 ServiceParms : parms ,
127- RequestSize : requestSize ,
120+ RequestSize : r ,
128121 Model : model ,
129122 RateRange : rateRange ,
130123 }
@@ -151,10 +144,9 @@ func (qa *QueueAnalyzer) Analyze(requestRate float32) (metrics *AnalysisMetrics,
151144
152145 // get statistics
153146 avgNumInServ := model .GetAvgNumInServers ()
154-
155- effConc := EffectiveConcurrency (model .GetAvgServTime (), qa .ServiceParms , qa .RequestSize , qa .MaxBatchSize )
156- prefillTime := qa .ServiceParms .Prefill .PrefillTime (qa .RequestSize .AvgInputTokens , effConc )
157- tokenTime := qa .ServiceParms .Decode .DecodeTime (effConc )
147+ avgPrefillTime := qa .ServiceParms .PrefillTime (qa .RequestSize , avgNumInServ )
148+ avgDecodeTime := (model .GetAvgServTime () - avgPrefillTime ) / qa .RequestSize .AvgOutputTokens
149+ avgTTFT := model .GetAvgWaitTime () + avgPrefillTime + avgDecodeTime
158150
159151 rho := avgNumInServ / float32 (qa .MaxBatchSize )
160152 rho = min (max (rho , 0 ), 1 )
@@ -165,18 +157,22 @@ func (qa *QueueAnalyzer) Analyze(requestRate float32) (metrics *AnalysisMetrics,
165157 AvgRespTime : model .GetAvgRespTime (),
166158 AvgWaitTime : model .GetAvgWaitTime (),
167159 AvgNumInServ : avgNumInServ ,
168- AvgPrefillTime : prefillTime ,
169- AvgTokenTime : tokenTime ,
160+ AvgPrefillTime : avgPrefillTime ,
161+ AvgTokenTime : avgDecodeTime ,
162+ AvgTTFT : avgTTFT ,
170163 MaxRate : rateRange .Max ,
171164 Rho : rho ,
172165 }
173166 return metrics , nil
174167}
175168
176- // global variables used by eval functions, to be set before calling eval function
177- var evalRequestSize * RequestSize // number of input and output tokens per request
178- var evalServiceParms * ServiceParms // request processing parameters for prefill and decode stages
179- var evalMaxBatchSize int // max batch size
169+ // model and parameters used in functional evaluation
170+ type EvalFuncData struct {
171+ model * MM1ModelStateDependent // queueing model
172+ requestSize * RequestSize // number of input and output tokens per request
173+ serviceParms * ServiceParms // request processing parameters for prefill and decode stages
174+ maxBatchSize int // max batch size
175+ }
180176
181177// evaluate max request rates to achieve a given target performance, returns
182178// - max request rates
@@ -193,18 +189,19 @@ func (qa *QueueAnalyzer) Size(targetPerf *TargetPerf) (targetRate *TargetRate, m
193189 lambdaMin := qa .RateRange .Min / 1000
194190 lambdaMax := qa .RateRange .Max / 1000
195191
196- // set global variables for model and parameters used in functional evaluation
197- Model = qa .Model
198- evalRequestSize = qa .RequestSize
199- evalServiceParms = qa .ServiceParms
200- evalMaxBatchSize = qa .MaxBatchSize
201-
192+ // indicator value returned by binary search
202193 var ind int
203194
204195 // find max rate to achieve target TTFT time
205196 lambdaStarTTFT := lambdaMax
206197 if targetTTFT > 0 {
207- lambdaStarTTFT , ind , err = BinarySearch (lambdaMin , lambdaMax , targetTTFT , EvalTTFT )
198+ evalTTF := EvalTTFT (& EvalFuncData {
199+ model : qa .Model ,
200+ requestSize : qa .RequestSize ,
201+ serviceParms : qa .ServiceParms ,
202+ maxBatchSize : qa .MaxBatchSize ,
203+ })
204+ lambdaStarTTFT , ind , err = BinarySearch (lambdaMin , lambdaMax , targetTTFT , evalTTF )
208205 if ind < 0 {
209206 err = fmt .Errorf ("target is below the bounded region" )
210207 }
@@ -217,7 +214,13 @@ func (qa *QueueAnalyzer) Size(targetPerf *TargetPerf) (targetRate *TargetRate, m
217214 // find max rate to achieve target ITL time
218215 lambdaStarITL := lambdaMax
219216 if targetITL > 0 {
220- lambdaStarITL , ind , err = BinarySearch (lambdaMin , lambdaMax , targetITL , EvalITL )
217+ evalITL := EvalITL (& EvalFuncData {
218+ model : qa .Model ,
219+ requestSize : qa .RequestSize ,
220+ serviceParms : qa .ServiceParms ,
221+ maxBatchSize : qa .MaxBatchSize ,
222+ })
223+ lambdaStarITL , ind , err = BinarySearch (lambdaMin , lambdaMax , targetITL , evalITL )
221224 if ind < 0 {
222225 err = fmt .Errorf ("target is below the bounded region" )
223226 }
@@ -247,133 +250,59 @@ func (qa *QueueAnalyzer) Size(targetPerf *TargetPerf) (targetRate *TargetRate, m
247250 }
248251
249252 achieved = & TargetPerf {
250- TargetTTFT : metrics .AvgWaitTime + metrics . AvgPrefillTime ,
253+ TargetTTFT : metrics .AvgTTFT ,
251254 TargetITL : metrics .AvgTokenTime ,
252- TargetTPS : metrics .Throughput * float32 ( qa .RequestSize .AvgOutputTokens ) ,
255+ TargetTPS : metrics .Throughput * qa .RequestSize .AvgOutputTokens ,
253256 }
254257 return targetRate , metrics , achieved , nil
255258}
256259
257- func (p * PrefillParms ) PrefillTime (avgInputTokens int , batchSize float32 ) float32 {
258- if avgInputTokens == 0 {
260+ // Average iteration time as a function of the batch size T(n)
261+ func (p * ServiceParms ) IterationTime (r * RequestSize , batchSize float32 ) float32 {
262+ tokensCompute := (r .AvgInputTokens + r .AvgOutputTokens ) / (r .AvgOutputTokens + 1 )
263+ tokensMemory := r .AvgInputTokens + r .AvgOutputTokens / 2
264+ return p .Alpha + batchSize * (p .Beta * tokensCompute + p .Gamma * tokensMemory )
265+ }
266+
267+ // Average prefill time as a function of the batch size
268+ func (p * ServiceParms ) PrefillTime (r * RequestSize , batchSize float32 ) float32 {
269+ if r .AvgInputTokens == 0 {
259270 return 0
260271 }
261- return p .Gamma + p . Delta * float32 ( avgInputTokens ) * batchSize
272+ return p .IterationTime ( r , batchSize ) + ( p . Beta + p . Gamma ) * r . AvgInputTokens
262273}
263274
264- func (p * DecodeParms ) DecodeTime (batchSize float32 ) float32 {
265- return p .Alpha + p .Beta * batchSize
275+ // Average decode time (generation of one token) as a function of the batch size
276+ func (p * ServiceParms ) DecodeTime (r * RequestSize , batchSize float32 ) float32 {
277+ return p .IterationTime (r , batchSize ) +
278+ p .Beta + p .Gamma * (r .AvgInputTokens + r .AvgOutputTokens / 2 )
266279}
267280
268281// Function used in binary search (target TTFT)
269282// - x is lambda req/msec
270- func EvalTTFT (x float32 ) (float32 , error ) {
271- Model .Solve (x , 1 )
272- if ! Model .IsValid () {
273- return 0 , fmt .Errorf ("invalid model %s" , Model )
283+ func EvalTTFT (data * EvalFuncData ) func (x float32 ) (float32 , error ) {
284+ return func (x float32 ) (float32 , error ) {
285+ data .model .Solve (x , 1 )
286+ if ! data .model .IsValid () {
287+ return 0 , fmt .Errorf ("invalid model %s" , data .model )
288+ }
289+ avgPrefillTime := data .serviceParms .PrefillTime (data .requestSize , data .model .GetAvgNumInServers ())
290+ avgDecodeTime := (data .model .GetAvgServTime () - avgPrefillTime ) / data .requestSize .AvgOutputTokens
291+ ttft := data .model .GetAvgWaitTime () + avgPrefillTime + avgDecodeTime
292+ return ttft , nil
274293 }
275- avgWaitTime := Model .GetAvgWaitTime ()
276- effConc := EffectiveConcurrency (Model .GetAvgServTime (), evalServiceParms , evalRequestSize , evalMaxBatchSize )
277- ttft := avgWaitTime + evalServiceParms .Prefill .PrefillTime (evalRequestSize .AvgInputTokens , effConc )
278- return ttft , nil
279294}
280295
281296// Function used in binary search (target ITL)
282297// - x is lambda req/msec
283- func EvalITL (x float32 ) (float32 , error ) {
284- Model .Solve (x , 1 )
285- if ! Model .IsValid () {
286- return 0 , fmt .Errorf ("invalid model %s" , Model )
287- }
288- effConc := EffectiveConcurrency (Model .GetAvgServTime (), evalServiceParms , evalRequestSize , evalMaxBatchSize )
289- return evalServiceParms .Decode .DecodeTime (effConc ), nil
290- }
291-
292- // calculate effective average number of requests in service (n), given average request service time
293- // - n has to satisfy: prefillTime(n) + totalDecodeTime(n) = avgServiceTime
294- // - prefillTime(n) = gamma + delta * inTokens * n
295- // - totalDecodeTime(n) = (alpha + beta * n) * (outTokens - 1)
296- func EffectiveConcurrency (avgServiceTime float32 , serviceParms * ServiceParms , requestSize * RequestSize , maxBatchSize int ) float32 {
297- tokens := float32 (requestSize .AvgOutputTokens - 1 )
298- numerator := avgServiceTime - (serviceParms .Prefill .Gamma + serviceParms .Decode .Alpha * tokens )
299- denominator := (serviceParms .Prefill .Delta * float32 (requestSize .AvgInputTokens )) + (serviceParms .Decode .Beta * tokens )
300- n := numerator / denominator
301- return min (max (n , 0 ), float32 (maxBatchSize ))
302- }
303-
304- // check validity of configuration parameters
305- func (c * Configuration ) check () error {
306- if c .MaxBatchSize <= 0 || c .MaxQueueSize < 0 || c .ServiceParms == nil ||
307- c .ServiceParms .Prefill == nil || c .ServiceParms .Decode == nil {
308- return fmt .Errorf ("invalid configuration %s" , c )
309- }
310- return nil
311- }
312-
313- // check validity of request size
314- func (rq * RequestSize ) check () error {
315- if rq .AvgInputTokens < 0 || rq .AvgOutputTokens < 1 {
316- return fmt .Errorf ("invalid request size %s" , rq )
317- }
318- return nil
319- }
320-
321- // check validity of target values
322- func (targetPerf * TargetPerf ) check () error {
323- if targetPerf .TargetITL < 0 ||
324- targetPerf .TargetTTFT < 0 ||
325- targetPerf .TargetTPS < 0 {
326- return fmt .Errorf ("invalid target data values %s" , targetPerf )
298+ func EvalITL (data * EvalFuncData ) func (x float32 ) (float32 , error ) {
299+ return func (x float32 ) (float32 , error ) {
300+ data .model .Solve (x , 1 )
301+ if ! data .model .IsValid () {
302+ return 0 , fmt .Errorf ("invalid model %s" , data .model )
303+ }
304+ avgPrefillTime := data .serviceParms .PrefillTime (data .requestSize , data .model .GetAvgNumInServers ())
305+ avgDecodeTime := (data .model .GetAvgServTime () - avgPrefillTime ) / data .requestSize .AvgOutputTokens
306+ return avgDecodeTime , nil
327307 }
328- return nil
329- }
330-
331- /*
332- * toString() functions
333- */
334-
335- func (c * Configuration ) String () string {
336- return fmt .Sprintf ("{maxBatch=%d, maxQueue=%d, servParms:%s}" ,
337- c .MaxBatchSize , c .MaxQueueSize , c .ServiceParms )
338- }
339-
340- func (qa * QueueAnalyzer ) String () string {
341- return fmt .Sprintf ("{maxBatch=%d, maxQueue=%d, servParms:%s, reqSize:%s, model:%s, rates:%s}" ,
342- qa .MaxBatchSize , qa .MaxQueueSize , qa .ServiceParms , qa .RequestSize , qa .Model , qa .RateRange )
343- }
344-
345- func (sp * ServiceParms ) String () string {
346- return fmt .Sprintf ("{prefillParms=%s, decodeParms=%s}" ,
347- sp .Prefill , sp .Decode )
348- }
349-
350- func (p * PrefillParms ) String () string {
351- return fmt .Sprintf ("{gamma=%.3f, delta=%.5f}" , p .Gamma , p .Delta )
352- }
353-
354- func (p * DecodeParms ) String () string {
355- return fmt .Sprintf ("{alpha=%.3f, beta=%.5f}" , p .Alpha , p .Beta )
356- }
357-
358- func (rq * RequestSize ) String () string {
359- return fmt .Sprintf ("{inTokens=%d, outTokens=%d}" , rq .AvgInputTokens , rq .AvgOutputTokens )
360- }
361-
362- func (rr * RateRange ) String () string {
363- return fmt .Sprintf ("[%.3f, %.3f]" , rr .Min , rr .Max )
364- }
365-
366- func (am * AnalysisMetrics ) String () string {
367- return fmt .Sprintf ("{tput=%.3f, lat=%.3f, wait=%.3f, conc=%.3f, prefill=%.3f, itl=%.3f, maxRate=%.3f, rho=%0.3f}" ,
368- am .Throughput , am .AvgRespTime , am .AvgWaitTime , am .AvgNumInServ , am .AvgPrefillTime , am .AvgTokenTime , am .MaxRate , am .Rho )
369- }
370-
371- func (tp * TargetPerf ) String () string {
372- return fmt .Sprintf ("{TTFT=%.3f, ITL=%.3f, TPS=%.3f}" ,
373- tp .TargetTTFT , tp .TargetITL , tp .TargetTPS )
374- }
375-
376- func (tr * TargetRate ) String () string {
377- return fmt .Sprintf ("{rateTTFT=%.3f, rateITL=%.3f, rateTPS=%.3f}" ,
378- tr .RateTargetTTFT , tr .RateTargetITL , tr .RateTargetTPS )
379308}
0 commit comments