Commit aff15ca

enhance queueing model used by queue analyzer (#727)
1 parent e0cab8c commit aff15ca

File tree

11 files changed: +326 -465 lines

pkg/analyzer/queueanalyzer.go

Lines changed: 89 additions & 160 deletions
```diff
@@ -10,9 +10,13 @@ const Epsilon = float32(0.001)
 // fraction of maximum server throughput to provide stability (running this fraction below the maximum)
 const StabilitySafetyFraction = float32(0.1)
 
+// maximum number of tokens per batch (iteration)
+const DefaultMaxNumTokens = 8192
+
 // Analyzer of inference server queue
 type QueueAnalyzer struct {
 	MaxBatchSize int           // maximum batch size
+	MaxNumTokens int           // maximum number of tokens per batch
 	MaxQueueSize int           // maximum queue size
 	ServiceParms *ServiceParms // request processing parameters
 	RequestSize  *RequestSize  // number of input and output tokens per request
```
```diff
@@ -23,32 +27,23 @@ type QueueAnalyzer struct {
 // queue configuration parameters
 type Configuration struct {
 	MaxBatchSize int           // maximum batch size (limit on the number of requests concurrently receiving service >0)
+	MaxNumTokens int           // maximum number of tokens per batch (limit on the number of tokens per batch >0)
 	MaxQueueSize int           // maximum queue size (limit on the number of requests queued for service >=0)
 	ServiceParms *ServiceParms // request processing parameters
 }
 
-// request processing parameters
+// request processing parameters:
+// iterationTime = alpha + beta * computeTime + gamma * memoryAccessTime
 type ServiceParms struct {
-	Prefill *PrefillParms // parameters to calculate prefill time
-	Decode  *DecodeParms  // parameters to calculate decode time
-}
-
-// prefill time = gamma + delta * inputTokens * batchSize (msec); inputTokens > 0
-type PrefillParms struct {
-	Gamma float32 // base
-	Delta float32 // slope
-}
-
-// decode time = alpha + beta * batchSize (msec); batchSize > 0
-type DecodeParms struct {
 	Alpha float32 // base
-	Beta  float32 // slope
+	Beta  float32 // slope for compute time
+	Gamma float32 // slope for memory access time
 }
 
 // request tokens data
 type RequestSize struct {
-	AvgInputTokens  int // average number of input tokens per request
-	AvgOutputTokens int // average number of output tokens per request
+	AvgInputTokens  float32 // average number of input tokens per request
+	AvgOutputTokens float32 // average number of output tokens per request
 }
 
 // range of request rates (requests/sec)
```
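The rewrite collapses the separate `PrefillParms`/`DecodeParms` pair into a single `ServiceParms` whose iteration time is linear in batch size: `Beta` scales with the tokens computed per iteration and `Gamma` with the tokens read from memory (KV cache). Below is a minimal standalone sketch of the new model; the coefficient values are made up purely for illustration, not calibrated measurements.

```go
package main

import "fmt"

// new single-stage service parameters (as in this diff):
// iterationTime = alpha + beta*computeTime + gamma*memoryAccessTime
type ServiceParms struct {
	Alpha float32 // base
	Beta  float32 // slope for compute time
	Gamma float32 // slope for memory access time
}

// average request size, now float32 so token counts can be fractional averages
type RequestSize struct {
	AvgInputTokens  float32
	AvgOutputTokens float32
}

// average iteration time T(n) for a batch of n requests, mirroring the method
// added at the bottom of this diff: a request's out+1 iterations (one prefill,
// out decodes) compute in+out tokens in total, and a decode iteration reads
// about in+out/2 cached tokens
func (p *ServiceParms) IterationTime(r *RequestSize, n float32) float32 {
	tokensCompute := (r.AvgInputTokens + r.AvgOutputTokens) / (r.AvgOutputTokens + 1)
	tokensMemory := r.AvgInputTokens + r.AvgOutputTokens/2
	return p.Alpha + n*(p.Beta*tokensCompute+p.Gamma*tokensMemory)
}

func main() {
	p := &ServiceParms{Alpha: 5, Beta: 0.01, Gamma: 0.001} // made-up coefficients (msec)
	r := &RequestSize{AvgInputTokens: 512, AvgOutputTokens: 128}
	for _, n := range []float32{1, 8, 32} {
		fmt.Printf("T(%g) = %.2f msec\n", n, p.IterationTime(r, n))
	}
}
```

Reading the formulas: averaged over a request's out+1 iterations, compute work is (in+out)/(out+1) tokens per iteration, while memory traffic is dominated by the KV cache, which averages roughly in + out/2 tokens during decode.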
```diff
@@ -65,6 +60,7 @@ type AnalysisMetrics struct {
 	AvgNumInServ   float32 // average number of requests in service
 	AvgPrefillTime float32 // average request prefill time (msec)
 	AvgTokenTime   float32 // average token decode time (msec)
+	AvgTTFT        float32 // average time to first token (msec)
 	MaxRate        float32 // maximum throughput (requests/sec)
 	Rho            float32 // utilization
 }
```
```diff
@@ -96,35 +92,32 @@ func NewQueueAnalyzer(qConfig *Configuration, requestSize *RequestSize) (*QueueA
 }
 
 // build queueing model using service rates, leaving arrival rate as parameter
-func BuildModel(qConfig *Configuration, requestSize *RequestSize) (modelData *QueueAnalyzer) {
-	parms := qConfig.ServiceParms
+func BuildModel(c *Configuration, r *RequestSize) (modelData *QueueAnalyzer) {
+	parms := c.ServiceParms
 
 	// calculate state-dependent service rate
-	servRate := make([]float32, qConfig.MaxBatchSize)
-	for n := 1; n <= qConfig.MaxBatchSize; n++ {
-		prefillTime := parms.Prefill.PrefillTime(requestSize.AvgInputTokens, float32(n))
-		numDecode := requestSize.AvgOutputTokens - 1 // number of decodes (one per output token except the first)
-		// special case: allow one decode in case of decode only and one output token
-		if requestSize.AvgInputTokens == 0 && requestSize.AvgOutputTokens == 1 {
-			numDecode = 1
-		}
-		decodeTime := float32(numDecode) * parms.Decode.DecodeTime(float32(n))
+	servRate := make([]float32, c.MaxBatchSize)
+	for n := 1; n <= c.MaxBatchSize; n++ {
+		prefillTime := parms.PrefillTime(r, float32(n))
+		decodeTime := r.AvgOutputTokens * parms.DecodeTime(r, float32(n))
 		servRate[n-1] = float32(n) / (prefillTime + decodeTime)
 	}
 
 	// set and check limits
 	lambdaMin := servRate[0] * Epsilon
-	lambdaMax := servRate[qConfig.MaxBatchSize-1] * (1 - Epsilon)
+	lambdaMax := servRate[c.MaxBatchSize-1] * (1 - Epsilon)
 	rateRange := &RateRange{Min: lambdaMin * 1000, Max: lambdaMax * 1000}
 
 	// create and solve model
-	occupancyUpperBound := qConfig.MaxQueueSize + qConfig.MaxBatchSize
+	occupancyUpperBound := c.MaxQueueSize + c.MaxBatchSize
 	model := NewMM1ModelStateDependent(occupancyUpperBound, servRate)
+
 	return &QueueAnalyzer{
-		MaxBatchSize: qConfig.MaxBatchSize,
-		MaxQueueSize: qConfig.MaxQueueSize,
+		MaxBatchSize: c.MaxBatchSize,
+		MaxNumTokens: c.MaxNumTokens,
+		MaxQueueSize: c.MaxQueueSize,
 		ServiceParms: parms,
-		RequestSize:  requestSize,
+		RequestSize:  r,
 		Model:        model,
 		RateRange:    rateRange,
 	}
```
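With prefill and decode both expressed through the shared iteration time, the state-dependent service rate needs no special case for decode-only requests: each request costs one prefill pass plus `AvgOutputTokens` decode passes, and a batch of n requests completes at rate n over that total. A sketch continuing the one above (the `PrefillTime`/`DecodeTime` bodies mirror the methods added near the bottom of this diff):

```go
// prefill and decode times per the new model (mirrors the diff)
func (p *ServiceParms) PrefillTime(r *RequestSize, n float32) float32 {
	if r.AvgInputTokens == 0 {
		return 0 // decode-only request: no prefill phase
	}
	return p.IterationTime(r, n) + (p.Beta+p.Gamma)*r.AvgInputTokens
}

func (p *ServiceParms) DecodeTime(r *RequestSize, n float32) float32 {
	return p.IterationTime(r, n) + p.Beta + p.Gamma*(r.AvgInputTokens+r.AvgOutputTokens/2)
}

// state-dependent service rate mu(n) = n / (prefill(n) + outTokens*decode(n)),
// the loop body of the new BuildModel in isolation
func serviceRates(p *ServiceParms, r *RequestSize, maxBatchSize int) []float32 {
	servRate := make([]float32, maxBatchSize)
	for n := 1; n <= maxBatchSize; n++ {
		prefill := p.PrefillTime(r, float32(n))
		decode := r.AvgOutputTokens * p.DecodeTime(r, float32(n))
		servRate[n-1] = float32(n) / (prefill + decode)
	}
	return servRate
}
```

Note how the old edge case (`numDecode = 1` when there are no input tokens and one output token) disappears: with one output token the single decode pass already carries the cost, and `PrefillTime` contributes zero.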
```diff
@@ -151,10 +144,9 @@ func (qa *QueueAnalyzer) Analyze(requestRate float32) (metrics *AnalysisMetrics,
 
 	// get statistics
 	avgNumInServ := model.GetAvgNumInServers()
-
-	effConc := EffectiveConcurrency(model.GetAvgServTime(), qa.ServiceParms, qa.RequestSize, qa.MaxBatchSize)
-	prefillTime := qa.ServiceParms.Prefill.PrefillTime(qa.RequestSize.AvgInputTokens, effConc)
-	tokenTime := qa.ServiceParms.Decode.DecodeTime(effConc)
+	avgPrefillTime := qa.ServiceParms.PrefillTime(qa.RequestSize, avgNumInServ)
+	avgDecodeTime := (model.GetAvgServTime() - avgPrefillTime) / qa.RequestSize.AvgOutputTokens
+	avgTTFT := model.GetAvgWaitTime() + avgPrefillTime + avgDecodeTime
 
 	rho := avgNumInServ / float32(qa.MaxBatchSize)
 	rho = min(max(rho, 0), 1)
```
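`Analyze` now derives the decode time from the model instead of the retired `EffectiveConcurrency` fit: whatever service time the queueing model reports beyond the prefill phase is spread evenly over the output tokens. For example (illustrative numbers), if the model reports an average wait of 20 msec and an average service time of 1000 msec, and `PrefillTime` evaluates to 40 msec with `AvgOutputTokens` = 128, then the per-token decode time is (1000 - 40) / 128 = 7.5 msec and the reported TTFT is 20 + 40 + 7.5 = 67.5 msec.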
```diff
@@ -165,18 +157,22 @@
 		AvgRespTime:    model.GetAvgRespTime(),
 		AvgWaitTime:    model.GetAvgWaitTime(),
 		AvgNumInServ:   avgNumInServ,
-		AvgPrefillTime: prefillTime,
-		AvgTokenTime:   tokenTime,
+		AvgPrefillTime: avgPrefillTime,
+		AvgTokenTime:   avgDecodeTime,
+		AvgTTFT:        avgTTFT,
 		MaxRate:        rateRange.Max,
 		Rho:            rho,
 	}
 	return metrics, nil
 }
 
-// global variables used by eval functions, to be set before calling eval function
-var evalRequestSize *RequestSize   // number of input and output tokens per request
-var evalServiceParms *ServiceParms // request processing parameters for prefill and decode stages
-var evalMaxBatchSize int           // max batch size
+// model and parameters used in functional evaluation
+type EvalFuncData struct {
+	model        *MM1ModelStateDependent // queueing model
+	requestSize  *RequestSize            // number of input and output tokens per request
+	serviceParms *ServiceParms           // request processing parameters for prefill and decode stages
+	maxBatchSize int                     // max batch size
+}
 
 // evaluate max request rates to achieve a given target performance, returns
 // - max request rates
```
```diff
@@ -193,18 +189,19 @@ func (qa *QueueAnalyzer) Size(targetPerf *TargetPerf) (targetRate *TargetRate, m
 	lambdaMin := qa.RateRange.Min / 1000
 	lambdaMax := qa.RateRange.Max / 1000
 
-	// set global variables for model and parameters used in functional evaluation
-	Model = qa.Model
-	evalRequestSize = qa.RequestSize
-	evalServiceParms = qa.ServiceParms
-	evalMaxBatchSize = qa.MaxBatchSize
-
+	// indicator value returned by binary search
 	var ind int
 
 	// find max rate to achieve target TTFT time
 	lambdaStarTTFT := lambdaMax
 	if targetTTFT > 0 {
-		lambdaStarTTFT, ind, err = BinarySearch(lambdaMin, lambdaMax, targetTTFT, EvalTTFT)
+		evalTTF := EvalTTFT(&EvalFuncData{
+			model:        qa.Model,
+			requestSize:  qa.RequestSize,
+			serviceParms: qa.ServiceParms,
+			maxBatchSize: qa.MaxBatchSize,
+		})
+		lambdaStarTTFT, ind, err = BinarySearch(lambdaMin, lambdaMax, targetTTFT, evalTTF)
 		if ind < 0 {
 			err = fmt.Errorf("target is below the bounded region")
 		}
```
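Threading `EvalFuncData` through a closure removes the package-level `Model` and `eval*` globals, so two analyzers can be sized independently without sharing mutable state. Below is a self-contained sketch of the pattern with a toy latency function; the `binarySearch` contract here (largest x meeting the target, negative indicator when the target lies below the curve) is inferred from the call sites in this diff, not copied from the repo:

```go
package main

import "fmt"

// data captured by the eval closure instead of package-level globals
type evalData struct {
	base, slope float32 // toy latency model: latency(x) = base + slope*x
}

// closure factory in the style of EvalTTFT/EvalITL: bind the data once,
// return a func(x) suitable for handing to a binary search
func evalLatency(d *evalData) func(x float32) (float32, error) {
	return func(x float32) (float32, error) {
		return d.base + d.slope*x, nil
	}
}

// binarySearch finds the largest x in [lo, hi] with f(x) <= target, assuming
// f is nondecreasing; ind < 0 flags a target below the bounded region
func binarySearch(lo, hi, target float32, f func(float32) (float32, error)) (x float32, ind int, err error) {
	if v, _ := f(lo); v > target {
		return lo, -1, nil
	}
	for i := 0; i < 50; i++ {
		mid := (lo + hi) / 2
		v, err := f(mid)
		if err != nil {
			return 0, 0, err
		}
		if v <= target {
			lo = mid
		} else {
			hi = mid
		}
	}
	return lo, 0, nil
}

func main() {
	eval := evalLatency(&evalData{base: 10, slope: 2})
	x, ind, _ := binarySearch(0, 100, 50, eval)
	fmt.Printf("x*=%.3f ind=%d\n", x, ind) // latency(x)=10+2x hits 50 at x=20
}
```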
```diff
@@ -217,7 +214,13 @@
 	// find max rate to achieve target ITL time
 	lambdaStarITL := lambdaMax
 	if targetITL > 0 {
-		lambdaStarITL, ind, err = BinarySearch(lambdaMin, lambdaMax, targetITL, EvalITL)
+		evalITL := EvalITL(&EvalFuncData{
+			model:        qa.Model,
+			requestSize:  qa.RequestSize,
+			serviceParms: qa.ServiceParms,
+			maxBatchSize: qa.MaxBatchSize,
+		})
+		lambdaStarITL, ind, err = BinarySearch(lambdaMin, lambdaMax, targetITL, evalITL)
 		if ind < 0 {
 			err = fmt.Errorf("target is below the bounded region")
 		}
```
```diff
@@ -247,133 +250,59 @@
 	}
 
 	achieved = &TargetPerf{
-		TargetTTFT: metrics.AvgWaitTime + metrics.AvgPrefillTime,
+		TargetTTFT: metrics.AvgTTFT,
 		TargetITL:  metrics.AvgTokenTime,
-		TargetTPS:  metrics.Throughput * float32(qa.RequestSize.AvgOutputTokens),
+		TargetTPS:  metrics.Throughput * qa.RequestSize.AvgOutputTokens,
 	}
 	return targetRate, metrics, achieved, nil
 }
 
-func (p *PrefillParms) PrefillTime(avgInputTokens int, batchSize float32) float32 {
-	if avgInputTokens == 0 {
+// Average iteration time as a function of the batch size T(n)
+func (p *ServiceParms) IterationTime(r *RequestSize, batchSize float32) float32 {
+	tokensCompute := (r.AvgInputTokens + r.AvgOutputTokens) / (r.AvgOutputTokens + 1)
+	tokensMemory := r.AvgInputTokens + r.AvgOutputTokens/2
+	return p.Alpha + batchSize*(p.Beta*tokensCompute+p.Gamma*tokensMemory)
+}
+
+// Average prefill time as a function of the batch size
+func (p *ServiceParms) PrefillTime(r *RequestSize, batchSize float32) float32 {
+	if r.AvgInputTokens == 0 {
 		return 0
 	}
-	return p.Gamma + p.Delta*float32(avgInputTokens)*batchSize
+	return p.IterationTime(r, batchSize) + (p.Beta+p.Gamma)*r.AvgInputTokens
 }
 
-func (p *DecodeParms) DecodeTime(batchSize float32) float32 {
-	return p.Alpha + p.Beta*batchSize
+// Average decode time (generation of one token) as a function of the batch size
+func (p *ServiceParms) DecodeTime(r *RequestSize, batchSize float32) float32 {
+	return p.IterationTime(r, batchSize) +
+		p.Beta + p.Gamma*(r.AvgInputTokens+r.AvgOutputTokens/2)
 }
 
 // Function used in binary search (target TTFT)
 // - x is lambda req/msec
-func EvalTTFT(x float32) (float32, error) {
-	Model.Solve(x, 1)
-	if !Model.IsValid() {
-		return 0, fmt.Errorf("invalid model %s", Model)
+func EvalTTFT(data *EvalFuncData) func(x float32) (float32, error) {
+	return func(x float32) (float32, error) {
+		data.model.Solve(x, 1)
+		if !data.model.IsValid() {
+			return 0, fmt.Errorf("invalid model %s", data.model)
+		}
+		avgPrefillTime := data.serviceParms.PrefillTime(data.requestSize, data.model.GetAvgNumInServers())
+		avgDecodeTime := (data.model.GetAvgServTime() - avgPrefillTime) / data.requestSize.AvgOutputTokens
+		ttft := data.model.GetAvgWaitTime() + avgPrefillTime + avgDecodeTime
+		return ttft, nil
 	}
-	avgWaitTime := Model.GetAvgWaitTime()
-	effConc := EffectiveConcurrency(Model.GetAvgServTime(), evalServiceParms, evalRequestSize, evalMaxBatchSize)
-	ttft := avgWaitTime + evalServiceParms.Prefill.PrefillTime(evalRequestSize.AvgInputTokens, effConc)
-	return ttft, nil
 }
 
 // Function used in binary search (target ITL)
 // - x is lambda req/msec
-func EvalITL(x float32) (float32, error) {
-	Model.Solve(x, 1)
-	if !Model.IsValid() {
-		return 0, fmt.Errorf("invalid model %s", Model)
-	}
-	effConc := EffectiveConcurrency(Model.GetAvgServTime(), evalServiceParms, evalRequestSize, evalMaxBatchSize)
-	return evalServiceParms.Decode.DecodeTime(effConc), nil
-}
-
-// calculate effective average number of requests in service (n), given average request service time
-// - n has to satisfy: prefillTime(n) + totalDecodeTime(n) = avgServiceTime
-// - prefillTime(n) = gamma + delta * inTokens * n
-// - totalDecodeTime(n) = (alpha + beta * n) * (outTokens - 1)
-func EffectiveConcurrency(avgServiceTime float32, serviceParms *ServiceParms, requestSize *RequestSize, maxBatchSize int) float32 {
-	tokens := float32(requestSize.AvgOutputTokens - 1)
-	numerator := avgServiceTime - (serviceParms.Prefill.Gamma + serviceParms.Decode.Alpha*tokens)
-	denominator := (serviceParms.Prefill.Delta * float32(requestSize.AvgInputTokens)) + (serviceParms.Decode.Beta * tokens)
-	n := numerator / denominator
-	return min(max(n, 0), float32(maxBatchSize))
-}
-
-// check validity of configuration parameters
-func (c *Configuration) check() error {
-	if c.MaxBatchSize <= 0 || c.MaxQueueSize < 0 || c.ServiceParms == nil ||
-		c.ServiceParms.Prefill == nil || c.ServiceParms.Decode == nil {
-		return fmt.Errorf("invalid configuration %s", c)
-	}
-	return nil
-}
-
-// check validity of request size
-func (rq *RequestSize) check() error {
-	if rq.AvgInputTokens < 0 || rq.AvgOutputTokens < 1 {
-		return fmt.Errorf("invalid request size %s", rq)
-	}
-	return nil
-}
-
-// check validity of target values
-func (targetPerf *TargetPerf) check() error {
-	if targetPerf.TargetITL < 0 ||
-		targetPerf.TargetTTFT < 0 ||
-		targetPerf.TargetTPS < 0 {
-		return fmt.Errorf("invalid target data values %s", targetPerf)
+func EvalITL(data *EvalFuncData) func(x float32) (float32, error) {
+	return func(x float32) (float32, error) {
+		data.model.Solve(x, 1)
+		if !data.model.IsValid() {
+			return 0, fmt.Errorf("invalid model %s", data.model)
+		}
+		avgPrefillTime := data.serviceParms.PrefillTime(data.requestSize, data.model.GetAvgNumInServers())
+		avgDecodeTime := (data.model.GetAvgServTime() - avgPrefillTime) / data.requestSize.AvgOutputTokens
+		return avgDecodeTime, nil
 	}
-	return nil
-}
-
-/*
- * toString() functions
- */
-
-func (c *Configuration) String() string {
-	return fmt.Sprintf("{maxBatch=%d, maxQueue=%d, servParms:%s}",
-		c.MaxBatchSize, c.MaxQueueSize, c.ServiceParms)
-}
-
-func (qa *QueueAnalyzer) String() string {
-	return fmt.Sprintf("{maxBatch=%d, maxQueue=%d, servParms:%s, reqSize:%s, model:%s, rates:%s}",
-		qa.MaxBatchSize, qa.MaxQueueSize, qa.ServiceParms, qa.RequestSize, qa.Model, qa.RateRange)
-}
-
-func (sp *ServiceParms) String() string {
-	return fmt.Sprintf("{prefillParms=%s, decodeParms=%s}",
-		sp.Prefill, sp.Decode)
-}
-
-func (p *PrefillParms) String() string {
-	return fmt.Sprintf("{gamma=%.3f, delta=%.5f}", p.Gamma, p.Delta)
-}
-
-func (p *DecodeParms) String() string {
-	return fmt.Sprintf("{alpha=%.3f, beta=%.5f}", p.Alpha, p.Beta)
-}
-
-func (rq *RequestSize) String() string {
-	return fmt.Sprintf("{inTokens=%d, outTokens=%d}", rq.AvgInputTokens, rq.AvgOutputTokens)
-}
-
-func (rr *RateRange) String() string {
-	return fmt.Sprintf("[%.3f, %.3f]", rr.Min, rr.Max)
-}
-
-func (am *AnalysisMetrics) String() string {
-	return fmt.Sprintf("{tput=%.3f, lat=%.3f, wait=%.3f, conc=%.3f, prefill=%.3f, itl=%.3f, maxRate=%.3f, rho=%0.3f}",
-		am.Throughput, am.AvgRespTime, am.AvgWaitTime, am.AvgNumInServ, am.AvgPrefillTime, am.AvgTokenTime, am.MaxRate, am.Rho)
-}
-
-func (tp *TargetPerf) String() string {
-	return fmt.Sprintf("{TTFT=%.3f, ITL=%.3f, TPS=%.3f}",
-		tp.TargetTTFT, tp.TargetITL, tp.TargetTPS)
-}
-
-func (tr *TargetRate) String() string {
-	return fmt.Sprintf("{rateTTFT=%.3f, rateITL=%.3f, rateTPS=%.3f}",
-		tr.RateTargetTTFT, tr.RateTargetITL, tr.RateTargetTPS)
 }
```
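A hypothetical end-to-end use of the reworked analyzer, as a fragment rather than a complete program: `NewQueueAnalyzer` is assumed to return `(*QueueAnalyzer, error)` since its signature is truncated in the hunk header above, `Analyze` is assumed to take requests/sec, and all coefficient values are placeholders, not calibrated measurements.

```go
// sketch only: uses this package's Configuration, ServiceParms, RequestSize,
// and DefaultMaxNumTokens as shown in the diff above
qConfig := &Configuration{
	MaxBatchSize: 32,
	MaxNumTokens: DefaultMaxNumTokens,
	MaxQueueSize: 128,
	ServiceParms: &ServiceParms{Alpha: 5, Beta: 0.01, Gamma: 0.001}, // placeholder coefficients
}
reqSize := &RequestSize{AvgInputTokens: 512, AvgOutputTokens: 128}

qa, err := NewQueueAnalyzer(qConfig, reqSize) // assumed to return (*QueueAnalyzer, error)
if err != nil {
	log.Fatal(err)
}
metrics, err := qa.Analyze(4.0) // offered load, assumed requests/sec
if err != nil {
	log.Fatal(err)
}
fmt.Printf("TTFT=%.1f msec  ITL=%.1f msec  rho=%.2f\n",
	metrics.AvgTTFT, metrics.AvgTokenTime, metrics.Rho)
```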
