
Commit ee02357

Apply time factor under load to prefill and inter token latency

Signed-off-by: Qifan Deng <[email protected]>
1 parent 063683b · commit ee02357

This commit extends the TimeFactorUnderLoad scaling, previously applied only to TimeToFirstToken, to the prefill overhead, the prefill time per token, and the inter-token latency. It also simplifies call sites: instead of threading a *chan int64 through sendResponse, getTimeToFirstToken, and sendTokenChunks, the simulator now reads its own runReqChan directly.

File tree

5 files changed (+79 −23 lines):

- pkg/common/config.go
- pkg/common/config_test.go
- pkg/llm-d-inference-sim/simulator.go
- pkg/llm-d-inference-sim/simulator_test.go
- pkg/llm-d-inference-sim/streaming.go

pkg/common/config.go
Lines changed: 12 additions & 0 deletions

@@ -186,6 +186,18 @@ func (c *Configuration) GetTimeToFirstToken(runReqChan *chan int64) int {
 	return int(float64(c.TimeToFirstToken) * c.calcLoadFactor(runReqChan))
 }
 
+func (c *Configuration) GetPrefillOverhead(runReqChan *chan int64) int {
+	return int(float64(c.PrefillOverhead) * c.calcLoadFactor(runReqChan))
+}
+
+func (c *Configuration) GetPrefillTimePerToken(runReqChan *chan int64) int {
+	return int(float64(c.PrefillTimePerToken) * c.calcLoadFactor(runReqChan))
+}
+
+func (c *Configuration) GetInterTokenLatency(runReqChan *chan int64) int {
+	return int(float64(c.InterTokenLatency) * c.calcLoadFactor(runReqChan))
+}
+
 type Metrics struct {
 	// LoraMetrics
 	LoraMetrics []LorasMetrics `json:"loras"`
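
All four getters delegate to calcLoadFactor, whose body is not part of this diff. Judging from the tests below (the factor is exactly 1.0 when TimeFactorUnderLoad is 1.0, exactly TimeFactorUnderLoad when the running-request channel is full, and strictly in between at partial load), a minimal sketch consistent with that behavior is a linear interpolation over channel occupancy. This is an assumption, not the repository's actual implementation:

// Hypothetical sketch of calcLoadFactor, assuming linear interpolation
// between 1.0 (idle) and TimeFactorUnderLoad (fully loaded); the real
// body is not shown in this commit.
func (c *Configuration) calcLoadFactor(runReqChan *chan int64) float64 {
	if c.TimeFactorUnderLoad <= 1.0 || c.MaxNumSeqs == 0 {
		return 1.0
	}
	load := float64(len(*runReqChan)) / float64(c.MaxNumSeqs) // occupancy in [0, 1]
	return 1.0 + (c.TimeFactorUnderLoad-1.0)*load
}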

pkg/common/config_test.go
Lines changed: 44 additions & 0 deletions

@@ -461,4 +461,48 @@ var _ = Describe("Simulator configuration", func() {
 		})
 	})
 }
+
+	It("when TimeFactorUnderLoad is 1.0, calcLoadFactor should give 1", func() {
+		c := newConfig()
+		c.TimeFactorUnderLoad = 1.0
+		c.MaxNumSeqs = 11
+		reqChan := make(chan int64, 3)
+		for i := 0; i < 3; i++ {
+			reqChan <- 1
+		}
+
+		factor := c.calcLoadFactor(&reqChan)
+		Expect(factor).To(BeNumerically("==", 1.0))
+		close(reqChan)
+	})
+
+	It("when TimeFactorUnderLoad is > 1.0, and sim is fully loaded, calcLoadFactor should give TimeFactorUnderLoad", func() {
+		c := newConfig()
+		c.TimeFactorUnderLoad = 2.0
+		c.MaxNumSeqs = 11
+		reqChan := make(chan int64, c.MaxNumSeqs)
+		for i := 0; i < c.MaxNumSeqs; i++ {
+			reqChan <- 1
+		}
+
+		factor := c.calcLoadFactor(&reqChan)
+		Expect(factor).To(BeNumerically("==", c.TimeFactorUnderLoad))
+		close(reqChan)
+
+	})
+
+	It("when TimeFactorUnderLoad is > 1.0, and sim is partially loaded, calcLoadFactor should give a value between 1 and TimeFactorUnderLoad", func() {
+		c := newConfig()
+		c.TimeFactorUnderLoad = 2.0
+		c.MaxNumSeqs = 11
+		reqChan := make(chan int64, c.MaxNumSeqs)
+		for i := 0; i < c.MaxNumSeqs/2; i++ {
+			reqChan <- 1
+		}
+		factor := c.calcLoadFactor(&reqChan)
+		Expect(factor).To(BeNumerically(">", 1.0))
+		Expect(factor).To(BeNumerically("<", c.TimeFactorUnderLoad))
+		close(reqChan)
+
+	})
 })
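
Under the linear-interpolation assumption sketched above, the partial-load test's expectation follows from the numbers it sets up: 5 of 11 slots occupied (MaxNumSeqs/2 is integer division) with TimeFactorUnderLoad = 2.0 gives a factor strictly between 1.0 and 2.0. A quick standalone check (hypothetical, since calcLoadFactor's body is not in this diff):

package main

import "fmt"

func main() {
	// 5 of 11 running requests, TimeFactorUnderLoad = 2.0,
	// assuming factor = 1 + (TimeFactorUnderLoad-1) * occupancy.
	factor := 1.0 + (2.0-1.0)*float64(5)/float64(11)
	fmt.Println(factor) // ≈1.4545 — strictly between 1.0 and 2.0
}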

pkg/llm-d-inference-sim/simulator.go
Lines changed: 7 additions & 7 deletions

@@ -537,7 +537,7 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
 					finishReason = common.RemoteDecodeFinishReason
 				}
 
-				s.sendResponse(reqCtx, responseTokens, toolCalls, displayModel, finishReason, &s.runReqChan, &usageData)
+				s.sendResponse(reqCtx, responseTokens, toolCalls, displayModel, finishReason, &usageData)
 			}
 		}
 		reqCtx.Wg.Done()

@@ -662,7 +662,7 @@ func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respToke
 // finishReason - a pointer to string that represents finish reason, can be nil, stop, length, or tools
 // usageData - usage (tokens statistics) for this response
 func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, respTokens []string, toolCalls []openaiserverapi.ToolCall,
-	modelName string, finishReason string, runReqChan *chan int64, usageData *openaiserverapi.Usage) {
+	modelName string, finishReason string, usageData *openaiserverapi.Usage) {
 	resp := s.createCompletionResponse(reqCtx.IsChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName,
 		reqCtx.CompletionReq.IsDoRemoteDecode())
 
@@ -677,7 +677,7 @@ func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, r
 	nPromptTokens := usageData.PromptTokens
 	nCachedPromptTokens := reqCtx.CompletionReq.GetNumberOfCachedPromptTokens()
 	nGenTokens := usageData.CompletionTokens
-	ttft := s.getTimeToFirstToken(nPromptTokens, nCachedPromptTokens, reqCtx.CompletionReq.IsDoRemotePrefill(), runReqChan)
+	ttft := s.getTimeToFirstToken(nPromptTokens, nCachedPromptTokens, reqCtx.CompletionReq.IsDoRemotePrefill())
 	totalMillisToWait := ttft + s.getTotalInterTokenLatency(nGenTokens)
 	time.Sleep(time.Duration(totalMillisToWait) * time.Millisecond)
 
@@ -696,7 +696,7 @@ func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, r
 }
 
 // returns time to first token based on the current request's doRemotePrefill
-func (s *VllmSimulator) getTimeToFirstToken(nPromptTokens int, nCachedPromptTokens int, doRemotePrefill bool, runReqChan *chan int64) int {
+func (s *VllmSimulator) getTimeToFirstToken(nPromptTokens int, nCachedPromptTokens int, doRemotePrefill bool) int {
 	if doRemotePrefill {
 		if s.config.KVCacheTransferLatency == 0 && s.config.KVCacheTransferLatencyStdDev == 0 {
 			// is disaggregated PD and ttft is calculated using number of prompt tokens

@@ -708,16 +708,16 @@ func (s *VllmSimulator) getTimeToFirstToken(nPromptTokens int, nCachedPromptToke
 	}
 	if s.config.TimeToFirstToken == 0 && s.config.TimeToFirstTokenStdDev == 0 {
 		// is aggregated PD and ttft is calculated using number of prompt tokens that are not in kv cache
-		prefillTime := s.config.PrefillOverhead + (nPromptTokens-nCachedPromptTokens)*s.config.PrefillTimePerToken
+		prefillTime := s.config.GetPrefillOverhead(&s.runReqChan) + (nPromptTokens-nCachedPromptTokens)*s.config.GetPrefillTimePerToken(&s.runReqChan)
 		return int(common.RandomNorm(float64(prefillTime), float64(s.config.PrefillTimeStdDev)))
 	}
 	// is aggregated PD and *not* using number of prompt tokens
-	return int(common.RandomNorm(float64(s.config.GetTimeToFirstToken(runReqChan)), float64(s.config.TimeToFirstTokenStdDev)))
+	return int(common.RandomNorm(float64(s.config.GetTimeToFirstToken(&s.runReqChan)), float64(s.config.TimeToFirstTokenStdDev)))
 }
 
 // returns inter token latency
 func (s *VllmSimulator) getInterTokenLatency() int {
-	mean := float64(s.config.InterTokenLatency)
+	mean := float64(s.config.GetInterTokenLatency(&s.runReqChan))
 	stddev := float64(s.config.InterTokenLatencyStdDev)
 	return int(common.RandomNorm(mean, stddev))
 }
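
With the new getters in place, the prefill-based TTFT path stretches end to end under load: both the fixed overhead and the per-token cost are multiplied by the load factor before the random-noise step. A rough standalone illustration of the arithmetic (hypothetical numbers, not the repository's defaults):

package main

import "fmt"

// applyLoadFactor mirrors the pattern the commit introduces: a base
// latency in milliseconds stretched by a factor in [1.0, TimeFactorUnderLoad].
func applyLoadFactor(baseMillis int, factor float64) int {
	return int(float64(baseMillis) * factor)
}

func main() {
	// Hypothetical: 100 ms prefill overhead, 10 ms per prompt token,
	// 128 uncached prompt tokens, zero standard deviation.
	overhead, perToken, nTokens := 100, 10, 128
	for _, factor := range []float64{1.0, 1.5, 2.0} {
		ttft := applyLoadFactor(overhead, factor) + nTokens*applyLoadFactor(perToken, factor)
		fmt.Printf("factor %.1f -> ttft %d ms\n", factor, ttft)
	}
}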

pkg/llm-d-inference-sim/simulator_test.go
Lines changed: 12 additions & 12 deletions

@@ -798,7 +798,7 @@ var _ = Describe("Simulator", func() {
 			simulator.config.TimeToFirstTokenStdDev = timeToFirstTokenStdDev
 			simulator.config.KVCacheTransferLatency = kvCacheLatency
 			simulator.config.KVCacheTransferLatencyStdDev = kvCacheLatencyStdDev
-			timeToFirst := simulator.getTimeToFirstToken(1, 0, doREmotePrefill, &simulator.runReqChan)
+			timeToFirst := simulator.getTimeToFirstToken(1, 0, doREmotePrefill)
 			if doREmotePrefill {
 				Expect(timeToFirst).To(BeNumerically(">=", int(float32(kvCacheLatency)*0.3)))
 				Expect(timeToFirst).To(BeNumerically("<=", int(float32(kvCacheLatency)*1.7)))

@@ -829,7 +829,7 @@ var _ = Describe("Simulator", func() {
 			simulator.config.PrefillTimePerToken = 200
 			simulator.config.PrefillTimeStdDev = 80
 
-			ttft := simulator.getTimeToFirstToken(128, 0, false, &simulator.runReqChan)
+			ttft := simulator.getTimeToFirstToken(128, 0, false)
 
 			Expect(ttft).To(BeNumerically("==", timeToFirstToken))
 		})

@@ -842,7 +842,7 @@ var _ = Describe("Simulator", func() {
 			simulator.config.PrefillTimePerToken = 200
 			simulator.config.PrefillTimeStdDev = 80
 
-			ttft := simulator.getTimeToFirstToken(128, 0, false, &simulator.runReqChan)
+			ttft := simulator.getTimeToFirstToken(128, 0, false)
 			Expect(ttft).NotTo(BeNumerically("==", 0))
 		})
 
@@ -853,7 +853,7 @@ var _ = Describe("Simulator", func() {
 			simulator.config.PrefillTimePerToken = prefillTimePerToken
 			simulator.config.PrefillTimeStdDev = stdDev
 
-			ttft := simulator.getTimeToFirstToken(nTokens, nCachedTokens, false, &simulator.runReqChan)
+			ttft := simulator.getTimeToFirstToken(nTokens, nCachedTokens, false)
 
 			expectedTTFT := prefillOverhead + prefillTimePerToken*(nTokens-nCachedTokens)
 			Expect(ttft).To(BeNumerically(">=", int(float64(expectedTTFT)*0.3)))

@@ -881,7 +881,7 @@ var _ = Describe("Simulator", func() {
 			simulator.config.PrefillTimePerToken = prefillTimePerToken
 			simulator.config.PrefillTimeStdDev = 0
 
-			ttft := simulator.getTimeToFirstToken(nTokens, nCachedTokens, false, &simulator.runReqChan)
+			ttft := simulator.getTimeToFirstToken(nTokens, nCachedTokens, false)
 			expectedTTFT := prefillOverhead + prefillTimePerToken*(nTokens-nCachedTokens)
 			Expect(ttft).To(Equal(expectedTTFT))
 		},

@@ -905,7 +905,7 @@ var _ = Describe("Simulator", func() {
 			simulator.config.KVCacheTransferTimePerToken = 100
 			simulator.config.KVCacheTransferTimeStdDev = 0
 
-			ttft := simulator.getTimeToFirstToken(128, 0, true, &simulator.runReqChan)
+			ttft := simulator.getTimeToFirstToken(128, 0, true)
 			Expect(ttft).To(BeNumerically("==", 200))
 		})
 
@@ -916,7 +916,7 @@ var _ = Describe("Simulator", func() {
 			simulator.config.KVCacheTransferTimePerToken = 100
 			simulator.config.KVCacheTransferTimeStdDev = 0
 
-			ttft := simulator.getTimeToFirstToken(128, 0, true, &simulator.runReqChan)
+			ttft := simulator.getTimeToFirstToken(128, 0, true)
 			Expect(ttft).To(BeNumerically("==", 12800))
 		})
 
@@ -927,7 +927,7 @@ var _ = Describe("Simulator", func() {
 			simulator.config.KVCacheTransferTimePerToken = kvCacheTransTPT
 			simulator.config.KVCacheTransferTimeStdDev = stddev
 
-			ttft := simulator.getTimeToFirstToken(nTokens, 0, true, &simulator.runReqChan)
+			ttft := simulator.getTimeToFirstToken(nTokens, 0, true)
 
 			expectedTTFT := kvCacheTransTPT * nTokens
 			Expect(ttft).To(BeNumerically(">=", int(float64(expectedTTFT)*0.3)))

@@ -952,7 +952,7 @@ var _ = Describe("Simulator", func() {
 
 			simulator.runReqChan <- 100
 
-			ttft := simulator.getTimeToFirstToken(128, 0, false, &simulator.runReqChan)
+			ttft := simulator.getTimeToFirstToken(128, 0, false)
 			Expect(ttft).To(Equal(42))
 		})
 
@@ -968,7 +968,7 @@ var _ = Describe("Simulator", func() {
 
 			simulator.runReqChan <- 1
 
-			ttft := simulator.getTimeToFirstToken(128, 0, false, &simulator.runReqChan)
+			ttft := simulator.getTimeToFirstToken(128, 0, false)
 			Expect(ttft).To(Equal(42))
 		})
 
@@ -985,7 +985,7 @@ var _ = Describe("Simulator", func() {
 				simulator.runReqChan <- 1
 			}
 
-			ttft := simulator.getTimeToFirstToken(128, 0, false, &simulator.runReqChan)
+			ttft := simulator.getTimeToFirstToken(128, 0, false)
 			Expect(ttft).To(Equal(int(float64(42) * timeFactorUnderLoad)))
 
 		},

@@ -1014,7 +1014,7 @@ var _ = Describe("Simulator", func() {
 				simulator.runReqChan <- 1
 			}
 
-			ttft := simulator.getTimeToFirstToken(128, 0, false, &simulator.runReqChan)
+			ttft := simulator.getTimeToFirstToken(128, 0, false)
 			max := timeFactorUnderLoad * float64(42)
 			Expect(ttft).To(BeNumerically(">=", 42))
 			Expect(ttft).To(BeNumerically("<=", max))

pkg/llm-d-inference-sim/streaming.go
Lines changed: 4 additions & 4 deletions

@@ -69,11 +69,11 @@ func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, respons
 		if len(toolCalls) > 0 {
 			s.logger.Info("Going to send tools calls")
 			for _, tc := range toolCalls {
-				s.sendTokenChunks(context, w, tc.Function.TokenizedArguments, &tc, finishReason, &s.runReqChan)
+				s.sendTokenChunks(context, w, tc.Function.TokenizedArguments, &tc, finishReason)
 			}
 		} else {
 			s.logger.Info("Going to send text", "number of tokens", len(responseTokens))
-			s.sendTokenChunks(context, w, responseTokens, nil, finishReason, &s.runReqChan)
+			s.sendTokenChunks(context, w, responseTokens, nil, finishReason)
 		}
 	}
 
@@ -97,9 +97,9 @@ func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, respons
 
 // sendTokenChunks creates and sends response chunks
 func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writer, genTokens []string,
-	tc *openaiserverapi.ToolCall, finishReason string, runReqChan *chan int64) {
+	tc *openaiserverapi.ToolCall, finishReason string) {
 	// time to first token delay
-	ttft := s.getTimeToFirstToken(context.nPromptTokens, context.nCachedPromptTokens, context.doRemotePrefill, runReqChan)
+	ttft := s.getTimeToFirstToken(context.nPromptTokens, context.nCachedPromptTokens, context.doRemotePrefill)
 	time.Sleep(time.Duration(ttft) * time.Millisecond)
 
 	for i, token := range genTokens {
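
Dropping the runReqChan parameter means the blocking path (sendResponse) and the streaming path (sendTokenChunks) now read the same load signal, s.runReqChan, directly from the simulator. With both TTFT and inter-token latency scaled, a streamed response stretches end to end under load. A rough estimate of total stream time, assuming one inter-token gap per generated token after the first (hypothetical numbers, zero standard deviation; the loop body above is not shown in this diff):

package main

import "fmt"

func main() {
	// Hypothetical: 42 ms TTFT, 10 ms inter-token latency, 100 tokens.
	ttft, interToken, nTokens := 42.0, 10.0, 100.0
	for _, factor := range []float64{1.0, 2.0} {
		total := ttft*factor + (nTokens-1)*interToken*factor
		fmt.Printf("factor %.1f -> ~%.0f ms end to end\n", factor, total)
	}
}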
