
Commit a79c33d

Add feature: calculate TTFT based on prefill overhead. TODO: KV-cache transfer overhead
Signed-off-by: Qifan Deng <[email protected]>
1 parent 0fe5652 commit a79c33d

File tree

5 files changed: +116 -12 lines changed

pkg/common/config.go
pkg/common/config_test.go
pkg/llm-d-inference-sim/simulator.go
pkg/llm-d-inference-sim/simulator_test.go
pkg/llm-d-inference-sim/streaming.go

pkg/common/config.go

Lines changed: 17 additions & 0 deletions
@@ -65,6 +65,11 @@ type Configuration struct {
 	// in milliseconds, optional, default is 0, can't be more than 30% of TimeToFirstToken, will not
 	// cause the actual time to first token to differ by more than 70% from TimeToFirstToken
 	TimeToFirstTokenStdDev int `yaml:"time-to-first-token-std-dev" json:"time-to-first-token-std-dev"`
+
+	// PrefillOverhead time taken to prefill the context, in milliseconds
+	PrefillOverhead int `yaml:"prefill-overhead" json:"prefill-overhead"`
+	PrefillOverheadComplexity string `yaml:"prefill-overhead-complexity" json:"prefill-overhead-complexity"`
+
 	// InterTokenLatency time between generated tokens, in milliseconds
 	InterTokenLatency int `yaml:"inter-token-latency" json:"inter-token-latency"`
 	// InterTokenLatencyStdDev standard deviation for time between generated tokens, in milliseconds,
@@ -295,6 +300,16 @@ func (c *Configuration) validate() error {
 	if float32(c.TimeToFirstTokenStdDev) > 0.3*float32(c.TimeToFirstToken) {
 		return errors.New("time to first token standard deviation cannot be more than 30% of time to first token")
 	}
+	if c.PrefillOverhead < 0 {
+		return errors.New("prefill overhead cannot be negative")
+	} else if c.PrefillOverhead == 0 {
+		if c.PrefillOverheadComplexity != "" {
+			return errors.New("prefill overhead complexity is set, but prefill overhead is 0")
+		}
+	}
+	if c.PrefillOverheadComplexity != "" && c.PrefillOverheadComplexity != "n^2" && c.PrefillOverheadComplexity != "nlog(n)" {
+		return errors.New("prefill overhead complexity should be either \"n^2\" or \"nlog(n)\"")
+	}
 	if c.KVCacheTransferLatency < 0 {
 		return errors.New("kv-cache tranfer time cannot be negative")
 	}
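
Note that validate() deliberately accepts an empty complexity string: the switch in calcPrefillOverhead (simulator.go, below) treats "" the same as "n^2", which is what the flag help text below means by 'Default is "n^2"'.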
@@ -400,6 +415,8 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) {
 	f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode, echo - returns the same text that was sent in the request, for chat completion returns the last message, random - returns random sentence from a bank of pre-defined sentences")
 	f.IntVar(&config.InterTokenLatency, "inter-token-latency", config.InterTokenLatency, "Time to generate one token (in milliseconds)")
 	f.IntVar(&config.TimeToFirstToken, "time-to-first-token", config.TimeToFirstToken, "Time to first token (in milliseconds)")
+	f.IntVar(&config.PrefillOverhead, "prefill-overhead", config.PrefillOverhead, "Time to prefill in milliseconds. This argument is ignored if <time-to-first-token> is not 0.")
+	f.StringVar(&config.PrefillOverheadComplexity, "prefill-overhead-complexity", config.PrefillOverheadComplexity, "Complexity of prefill based on token length. Options are \"n^2\" and \"nlog(n)\". Default is \"n^2\".")
 	f.IntVar(&config.KVCacheTransferLatency, "kv-cache-transfer-latency", config.KVCacheTransferLatency, "Time for KV-cache transfer from a remote vLLM (in milliseconds)")
 	f.IntVar(&config.InterTokenLatencyStdDev, "inter-token-latency-std-dev", config.InterTokenLatencyStdDev, "Standard deviation for time between generated tokens (in milliseconds)")
 	f.IntVar(&config.TimeToFirstTokenStdDev, "time-to-first-token-std-dev", config.TimeToFirstTokenStdDev, "Standard deviation for time before the first token will be returned (in milliseconds)")
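
Taken together, the struct tags and flags above make the new knobs reachable from both YAML and the command line. A minimal sketch of the YAML side (assuming gopkg.in/yaml.v3 as the YAML dependency and placement inside package common, since the fields sit on Configuration; not part of this commit):

package common

import (
	"fmt"

	"gopkg.in/yaml.v3"
)

// demoPrefillOverheadConfig is a hypothetical helper showing that the
// new fields round-trip from a YAML config via the tags declared above.
func demoPrefillOverheadConfig() error {
	raw := []byte("prefill-overhead: 100\nprefill-overhead-complexity: nlog(n)\n")
	var c Configuration
	if err := yaml.Unmarshal(raw, &c); err != nil {
		return err
	}
	fmt.Println(c.PrefillOverhead, c.PrefillOverheadComplexity) // 100 nlog(n)
	return nil
}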

pkg/common/config_test.go

Lines changed: 4 additions & 0 deletions
@@ -388,6 +388,10 @@ var _ = Describe("Simulator configuration", func() {
 			name: "invalid (negative) zmq-max-connect-attempts for config file",
 			args: []string{"cmd", "--config", "../../manifests/invalid-config.yaml"},
 		},
+		{
+			name: "<prefill-overhead> must be set when <prefill-overhead-complexity> is set",
+			args: []string{"cmd", "--prefill-overhead-complexity", "n^2", "--config", "../../manifests/config.yaml"},
+		},
 	}

 	for _, test := range invalidTests {

pkg/llm-d-inference-sim/simulator.go

Lines changed: 29 additions & 4 deletions
@@ -22,6 +22,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"math"
 	"net"
 	"os"
 	"strings"
@@ -465,7 +466,7 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
 				model:           displayModel,
 				doRemotePrefill: req.IsDoRemotePrefill(),
 			},
-			responseTokens, toolCalls, finishReason, usageDataToSend,
+			usageDataToSend.PromptTokens, responseTokens, toolCalls, finishReason, usageDataToSend,
 		)
 	} else {
 		if req.IsDoRemoteDecode() {
@@ -633,8 +634,9 @@ func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.Reques
 	}

 	// calculate how long to wait before returning the response, time is based on number of tokens
-	numOfTokens := usageData.CompletionTokens
-	totalMillisToWait := s.getTimeToFirstToken(doRemotePrefill) + s.getTotalInterTokenLatency(numOfTokens)
+	nPromptTokens := usageData.PromptTokens
+	nGenTokens := usageData.CompletionTokens
+	totalMillisToWait := s.getTimeToFirstToken(doRemotePrefill, nPromptTokens) + s.getTotalInterTokenLatency(nGenTokens)
 	time.Sleep(time.Duration(totalMillisToWait) * time.Millisecond)

 	ctx.Response.Header.SetContentType("application/json")
@@ -652,7 +654,14 @@ func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.Reques
 	}

 // returns time to first token based on the current request's doRemotePrefill
-func (s *VllmSimulator) getTimeToFirstToken(doRemotePrefill bool) int {
+func (s *VllmSimulator) getTimeToFirstToken(doRemotePrefill bool, nPromptTokens int) int {
+	if s.config.TimeToFirstToken == 0 && s.config.PrefillOverhead != 0 {
+		if nPromptTokens <= 1 {
+			return s.config.PrefillOverhead
+		}
+		return s.calcPrefillOverhead(nPromptTokens)
+	}
+
 	mean := float64(s.config.TimeToFirstToken)
 	stddev := float64(s.config.TimeToFirstTokenStdDev)
 	if doRemotePrefill {
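
To restate the branch order: a nonzero TimeToFirstToken always wins and PrefillOverhead is ignored, matching the flag help text; only when TimeToFirstToken is 0 and PrefillOverhead is nonzero does the prefill model apply, and prompts of at most one token short-circuit to the flat PrefillOverhead. A standalone sketch of that selection (hypothetical helper, not part of this commit):

package main

import "fmt"

// ttftSource mirrors the branch order of getTimeToFirstToken above;
// the real method goes on to apply std-dev sampling and the
// remote-prefill (KV-cache transfer) adjustments.
func ttftSource(timeToFirstToken, prefillOverhead, nPromptTokens int) string {
	if timeToFirstToken == 0 && prefillOverhead != 0 {
		if nPromptTokens <= 1 {
			return "flat PrefillOverhead"
		}
		return "calcPrefillOverhead(nPromptTokens)"
	}
	return "TimeToFirstToken path"
}

func main() {
	fmt.Println(ttftSource(10000, 100, 8)) // TimeToFirstToken path
	fmt.Println(ttftSource(0, 100, 1))     // flat PrefillOverhead
	fmt.Println(ttftSource(0, 100, 8))     // calcPrefillOverhead(nPromptTokens)
}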
@@ -678,6 +687,22 @@ func (s *VllmSimulator) getTotalInterTokenLatency(numOfTokens int) int {
 	return total
 }

+// calc the prefill overhead against number of tokens
+func (s *VllmSimulator) calcPrefillOverhead(nPromptTokens int) int {
+	pfOverhead := s.config.PrefillOverhead
+	complexity := s.config.PrefillOverheadComplexity
+	// policies of different complexities of prefill implementation
+	switch complexity {
+	case "n^2", "":
+		// this is simple implementation of n^2
+		return pfOverhead * nPromptTokens * nPromptTokens
+	case "nlog(n)":
+		return int(float64(pfOverhead) * (float64(nPromptTokens) * math.Log2(float64(nPromptTokens))))
+	}
+
+	return 0
+}
+
 // createModelsResponse creates and returns ModelResponse for the current state, returned array of models contains the base model + LoRA adapters if exist
 func (s *VllmSimulator) createModelsResponse() *vllmapi.ModelsResponse {
 	modelsResp := vllmapi.ModelsResponse{Object: "list", Data: []vllmapi.ModelsResponseModelInfo{}}
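
For a sense of scale: with PrefillOverhead = 100, a 10-token prompt costs 100 * 10 * 10 = 10000 ms under "n^2", but int(100 * 10 * log2(10)) = 3321 ms under "nlog(n)". A standalone sketch of the same policy as a free function (hypothetical name, not part of this commit):

package main

import (
	"fmt"
	"math"
)

// prefillMillis reproduces the calcPrefillOverhead policy above as a
// free function so the two complexity curves can be compared directly.
func prefillMillis(overhead, nPromptTokens int, complexity string) int {
	switch complexity {
	case "n^2", "": // "" falls back to the default quadratic policy
		return overhead * nPromptTokens * nPromptTokens
	case "nlog(n)":
		return int(float64(overhead) * float64(nPromptTokens) * math.Log2(float64(nPromptTokens)))
	}
	return 0
}

func main() {
	for _, n := range []int{2, 10, 100} {
		fmt.Printf("n=%3d  n^2: %7d ms  nlog(n): %5d ms\n",
			n, prefillMillis(100, n, "n^2"), prefillMillis(100, n, "nlog(n)"))
	}
}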

pkg/llm-d-inference-sim/simulator_test.go

Lines changed: 59 additions & 1 deletion
@@ -21,6 +21,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"math"
 	"net"
 	"net/http"
 	"os"
@@ -798,7 +799,7 @@ var _ = Describe("Simulator", func() {
 			simulator.config.TimeToFirstTokenStdDev = timeToFirstTokenStdDev
 			simulator.config.KVCacheTransferLatency = kvCacheLatency
 			simulator.config.KVCacheTransferLatencyStdDev = kvCacheLatencyStdDev
-			timeToFirst := simulator.getTimeToFirstToken(doREmotePrefill)
+			timeToFirst := simulator.getTimeToFirstToken(doREmotePrefill, 1)
 			if doREmotePrefill {
 				Expect(timeToFirst).To(BeNumerically(">=", int(float32(kvCacheLatency)*0.3)))
 				Expect(timeToFirst).To(BeNumerically("<=", int(float32(kvCacheLatency)*1.7)))
@@ -819,6 +820,63 @@ var _ = Describe("Simulator", func() {
 		Entry(nil, 10000, 0, 1000, 0, true),
 		Entry(nil, 10000, 0, 1000, 0, false),
 	)
+
+	It("when <time-to-first-token> is not 0, ignore <prefill-overhead>", func() {
+		timeToFirstToken := 10000
+		prefillOverhead := 100
+		simulator.config.TimeToFirstToken = timeToFirstToken
+		simulator.config.PrefillOverhead = prefillOverhead
+		timeToFirst := simulator.getTimeToFirstToken(false, 1)
+		Expect(timeToFirst).To(BeNumerically(">=", int(float32(timeToFirstToken)*0.3)))
+		Expect(timeToFirst).To(BeNumerically("<=", int(float32(timeToFirstToken)*1.7)))
+	})
+
+	It("when <time-to-first-token> is 0, use <prefill-overhead>", func() {
+		simulator.config.TimeToFirstToken = 0
+		simulator.config.PrefillOverhead = 100
+		timeToFirst := simulator.getTimeToFirstToken(false, 1)
+		Expect(timeToFirst).To(BeNumerically(">=", 100))
+	})
+
+	DescribeTable("time to first token is super linear of prefill against number of prompt tokens",
+		func(prefillOverhead int, tolerance float64, minNTokens int, maxNTokens int) {
+			for nTokens := minNTokens; nTokens <= maxNTokens; nTokens++ {
+				square := prefillOverhead * nTokens * nTokens
+				simulator.config.PrefillOverhead = prefillOverhead
+				timeToFirst := simulator.getTimeToFirstToken(false, nTokens)
+				diffRatio := math.Abs(float64(timeToFirst-square)) / float64(square)
+				Expect(diffRatio).To(BeNumerically("<", tolerance))
+			}
+		},
+		func(prefillOverhead int, tolerance float64, minNTokens int, maxNTokens int) string {
+			return fmt.Sprintf("prefillOverhead: %d tolerance: %f minNTokens: %d maxNTokens: %d",
+				prefillOverhead, tolerance, minNTokens, maxNTokens)
+		},
+		Entry("small numbers", 100, 0.1, 1, 10),
+		Entry("medium numbers, larger range", 200, 0.1, 50, 100),
+		Entry("large numbers", 150, 0.05, 20000, 20010),
+	)
+
+	DescribeTable("time to first token is log-linear of prefill against number of prompt tokens",
+		func(prefillOverhead int, tolerance float64, minNTokens int, maxNTokens int) {
+			simulator.config.PrefillOverheadComplexity = "nlog(n)"
+
+			for nTokens := minNTokens; nTokens <= maxNTokens; nTokens++ {
+				nlogn := int(float64(prefillOverhead) * float64(nTokens) * math.Log2(float64(nTokens)))
+				simulator.config.PrefillOverhead = prefillOverhead
+				timeToFirst := simulator.getTimeToFirstToken(false, nTokens)
+				diffRatio := math.Abs(float64(timeToFirst-nlogn)) / float64(nlogn)
+				Expect(diffRatio).To(BeNumerically("<", tolerance))
+			}
+		},
+		func(prefillOverhead int, tolerance float64, minNTokens int, maxNTokens int) string {
+			return fmt.Sprintf("prefillOverhead: %d tolerance: %f minNTokens: %d maxNTokens: %d",
+				prefillOverhead, tolerance, minNTokens, maxNTokens)
+		},
+		Entry("small numbers", 100, 0.1, 2, 10),
+		Entry("medium numbers, larger range", 200, 0.1, 50, 100),
+		Entry("large numbers", 150, 0.05, 20000, 20010),
+	)
 })

 	Context("fake metrics", func() {

pkg/llm-d-inference-sim/streaming.go

Lines changed: 7 additions & 7 deletions
@@ -39,7 +39,7 @@ type streamingContext struct {
 // as defined by isChatCompletion
 // response content is wrapped according SSE format
 // First token is send after timeToFirstToken milliseconds, every other token is sent after interTokenLatency milliseconds
-func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, responseTokens []string, toolCalls []openaiserverapi.ToolCall,
+func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, nPromptTokens int, responseTokens []string, toolCalls []openaiserverapi.ToolCall,
 	finishReason string, usageData *openaiserverapi.Usage) {
 	context.ctx.SetContentType("text/event-stream")
 	context.ctx.SetStatusCode(fasthttp.StatusOK)
@@ -67,11 +67,11 @@ func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, respons
 		if len(toolCalls) > 0 {
 			s.logger.Info("Going to send tools calls")
 			for _, tc := range toolCalls {
-				s.sendTokenChunks(context, w, tc.Function.TokenizedArguments, &tc, finishReason)
+				s.sendTokenChunks(context, w, nPromptTokens, tc.Function.TokenizedArguments, &tc, finishReason)
 			}
 		} else {
 			s.logger.Info("Going to send text", "number of tokens", len(responseTokens))
-			s.sendTokenChunks(context, w, responseTokens, nil, finishReason)
+			s.sendTokenChunks(context, w, nPromptTokens, responseTokens, nil, finishReason)
 		}
 	}

@@ -94,11 +94,11 @@ func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, respons
 	}

 // sendTokenChunks creates and sends response chunks
-func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writer, tokens []string, tc *openaiserverapi.ToolCall, finishReason string) {
+func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writer, nPromptTokens int, genTokens []string, tc *openaiserverapi.ToolCall, finishReason string) {
 	// time to first token delay
-	time.Sleep(time.Duration(s.getTimeToFirstToken(context.doRemotePrefill)) * time.Millisecond)
+	time.Sleep(time.Duration(s.getTimeToFirstToken(context.doRemotePrefill, nPromptTokens)) * time.Millisecond)

-	for i, token := range tokens {
+	for i, token := range genTokens {
 		if i != 0 {
 			time.Sleep(time.Duration(s.getInterTokenLatency()) * time.Millisecond)
 		}
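
Net effect for streaming: the first chunk is delayed by getTimeToFirstToken, now prompt-aware, and every later chunk by getInterTokenLatency, so a stream of n generated tokens completes after roughly TTFT + (n-1) * interTokenLatency milliseconds, mirroring the non-streaming sendResponse path above.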
@@ -119,7 +119,7 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ

 		var chunk openaiserverapi.CompletionRespChunk
 		var finishReasonToSend *string
-		if i == len(tokens)-1 && (finishReason == common.LengthFinishReason || finishReason == common.ToolsFinishReason) {
+		if i == len(genTokens)-1 && (finishReason == common.LengthFinishReason || finishReason == common.ToolsFinishReason) {
 			finishReasonToSend = &finishReason
 		}
 		if context.isChatCompletion {
