diff --git a/README.md b/README.md
index 8b1bd80c..adf7c48a 100644
--- a/README.md
+++ b/README.md
@@ -101,13 +101,22 @@ For more details see the
diff --git a/pkg/common/config.go b/pkg/common/config.go
 	if float32(c.TimeToFirstTokenStdDev) > 0.3*float32(c.TimeToFirstToken) {
 		return errors.New("time to first token standard deviation cannot be more than 30% of time to first token")
 	}
+
+	if c.PrefillOverhead < 0 {
+		return errors.New("prefill overhead cannot be negative")
+	}
+	if c.PrefillTimePerToken < 0 {
+		return errors.New("prefill time per token cannot be negative")
+	}
+	if c.PrefillTimeStdDev < 0 {
+		return errors.New("prefill time standard deviation cannot be negative")
+	}
+
+	if c.KVCacheTransferTimePerToken < 0 {
+		return errors.New("kv-cache transfer time per token cannot be negative")
+	}
+	if c.KVCacheTransferTimeStdDev < 0 {
+		return errors.New("kv-cache transfer time standard deviation cannot be negative")
+	}
+
 	if c.KVCacheTransferLatency < 0 {
 		return errors.New("kv-cache tranfer time cannot be negative")
 	}
@@ -316,6 +350,7 @@ func (c *Configuration) validate() error {
 	if float32(c.KVCacheTransferLatencyStdDev) > 0.3*float32(c.KVCacheTransferLatency) {
 		return errors.New("kv-cache tranfer standard deviation cannot be more than 30% of kv-cache tranfer")
 	}
+
 	if c.MaxLoras < 1 {
 		return errors.New("max LoRAs cannot be less than 1")
 	}
@@ -433,6 +468,13 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) {
 	f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode: echo - returns the same text that was sent in the request, for chat completion returns the last message; random - returns random sentence from a bank of pre-defined sentences")
 	f.IntVar(&config.InterTokenLatency, "inter-token-latency", config.InterTokenLatency, "Time to generate one token (in milliseconds)")
 	f.IntVar(&config.TimeToFirstToken, "time-to-first-token", config.TimeToFirstToken, "Time to first token (in milliseconds)")
+
+	f.IntVar(&config.PrefillOverhead, "prefill-overhead", config.PrefillOverhead, "Time to prefill in milliseconds. This argument is ignored if time-to-first-token is not 0.")
+	f.IntVar(&config.PrefillTimePerToken, "prefill-time-per-token", config.PrefillTimePerToken, "Time to prefill per token (in milliseconds)")
+	f.IntVar(&config.PrefillTimeStdDev, "prefill-time-std-dev", config.PrefillTimeStdDev, "Standard deviation for time to prefill (in milliseconds)")
+	f.IntVar(&config.KVCacheTransferTimePerToken, "kv-cache-transfer-time-per-token", config.KVCacheTransferTimePerToken, "Time for KV-cache transfer per token from a remote vLLM (in milliseconds)")
+	f.IntVar(&config.KVCacheTransferTimeStdDev, "kv-cache-transfer-time-std-dev", config.KVCacheTransferTimeStdDev, "Standard deviation for time for KV-cache transfer per token from a remote vLLM (in milliseconds)")
+
 	f.IntVar(&config.KVCacheTransferLatency, "kv-cache-transfer-latency", config.KVCacheTransferLatency, "Time for KV-cache transfer from a remote vLLM (in milliseconds)")
 	f.IntVar(&config.InterTokenLatencyStdDev, "inter-token-latency-std-dev", config.InterTokenLatencyStdDev, "Standard deviation for time between generated tokens (in milliseconds)")
 	f.IntVar(&config.TimeToFirstTokenStdDev, "time-to-first-token-std-dev", config.TimeToFirstTokenStdDev, "Standard deviation for time before the first token will be returned (in milliseconds)")
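Taken together, the new flags give the simulator two ways to derive the time to first token (TTFT): a fixed time-to-first-token value, or a per-prompt estimate built from prefill-overhead and prefill-time-per-token (and, for remote prefill, from kv-cache-transfer-time-per-token). As a rough illustration only, not part of the patch and with the RandomNorm noise omitted, the selection logic added in simulator.go further down reduces to the following means; latencyCfg and meanTTFT are made-up names for this sketch:

```go
// Sketch of the expected mean TTFT per the new config fields (RandomNorm noise omitted).
// latencyCfg mirrors only the relevant subset of Configuration; meanTTFT is a
// hypothetical helper, not part of this change.
type latencyCfg struct {
	TimeToFirstToken, TimeToFirstTokenStdDev             int
	PrefillOverhead, PrefillTimePerToken                 int
	KVCacheTransferLatency, KVCacheTransferLatencyStdDev int
	KVCacheTransferTimePerToken                          int
}

func meanTTFT(c latencyCfg, nPromptTokens int, doRemotePrefill bool) int {
	if doRemotePrefill { // disaggregated PD: prefill ran on a remote vLLM
		if c.KVCacheTransferLatency == 0 && c.KVCacheTransferLatencyStdDev == 0 {
			return c.KVCacheTransferTimePerToken * nPromptTokens
		}
		return c.KVCacheTransferLatency
	}
	// aggregated PD: prefill runs locally
	if c.TimeToFirstToken == 0 && c.TimeToFirstTokenStdDev == 0 {
		return c.PrefillOverhead + nPromptTokens*c.PrefillTimePerToken
	}
	return c.TimeToFirstToken
}
```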
diff --git a/pkg/common/config_test.go b/pkg/common/config_test.go
index 770716a6..7d5fae13 100644
--- a/pkg/common/config_test.go
+++ b/pkg/common/config_test.go
@@ -401,6 +401,31 @@ var _ = Describe("Simulator configuration", func() {
 			name: "invalid (negative) zmq-max-connect-attempts for config file",
 			args: []string{"cmd", "--config", "../../manifests/invalid-config.yaml"},
 		},
+		{
+			name: "invalid (negative) prefill-overhead",
+			args: []string{"cmd", "--prefill-overhead", "-1",
+				"--config", "../../manifests/config.yaml"},
+		},
+		{
+			name: "invalid (negative) prefill-time-per-token",
+			args: []string{"cmd", "--prefill-time-per-token", "-1",
+				"--config", "../../manifests/config.yaml"},
+		},
+		{
+			name: "invalid (negative) prefill-time-std-dev",
+			args: []string{"cmd", "--prefill-time-std-dev", "-1",
+				"--config", "../../manifests/config.yaml"},
+		},
+		{
+			name: "invalid (negative) kv-cache-transfer-time-per-token",
+			args: []string{"cmd", "--kv-cache-transfer-time-per-token", "-1",
+				"--config", "../../manifests/config.yaml"},
+		},
+		{
+			name: "invalid (negative) kv-cache-transfer-time-std-dev",
+			args: []string{"cmd", "--kv-cache-transfer-time-std-dev", "-1",
+				"--config", "../../manifests/config.yaml"},
+		},
 	}
 
 	for _, test := range invalidTests {
diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go
index e291f15d..2ecade9d 100644
--- a/pkg/llm-d-inference-sim/simulator.go
+++ b/pkg/llm-d-inference-sim/simulator.go
@@ -495,7 +495,7 @@ func (s *VllmSimulator) reqProcessingWorker(ctx context.Context, id int) {
 					model:           displayModel,
 					doRemotePrefill: req.IsDoRemotePrefill(),
 				},
-				responseTokens, toolCalls, finishReason, usageDataToSend,
+				usageDataToSend.PromptTokens, responseTokens, toolCalls, finishReason, usageDataToSend,
 			)
 		} else {
 			if req.IsDoRemoteDecode() {
@@ -646,8 +646,9 @@ func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.Reques
 	}
 
 	// calculate how long to wait before returning the response, time is based on number of tokens
-	numOfTokens := usageData.CompletionTokens
-	totalMillisToWait := s.getTimeToFirstToken(doRemotePrefill) + s.getTotalInterTokenLatency(numOfTokens)
+	nPromptTokens := usageData.PromptTokens
+	nGenTokens := usageData.CompletionTokens
+	totalMillisToWait := s.getTimeToFirstToken(nPromptTokens, doRemotePrefill) + s.getTotalInterTokenLatency(nGenTokens)
 	time.Sleep(time.Duration(totalMillisToWait) * time.Millisecond)
 
 	ctx.Response.Header.SetContentType("application/json")
@@ -665,14 +666,23 @@ func (s *VllmSimulator) sendResponse(isChatCompletion bool, ctx *fasthttp.Reques
 }
 
 // returns time to first token based on the current request's doRemotePrefill
-func (s *VllmSimulator) getTimeToFirstToken(doRemotePrefill bool) int {
-	mean := float64(s.config.TimeToFirstToken)
-	stddev := float64(s.config.TimeToFirstTokenStdDev)
+func (s *VllmSimulator) getTimeToFirstToken(nPromptTokens int, doRemotePrefill bool) int {
 	if doRemotePrefill {
-		mean = float64(s.config.KVCacheTransferLatency)
-		stddev = float64(s.config.KVCacheTransferLatencyStdDev)
+		if s.config.KVCacheTransferLatency == 0 && s.config.KVCacheTransferLatencyStdDev == 0 {
+			// is disaggregated PD and ttft is calculated using number of prompt tokens
+			kvCacheTransT := s.config.KVCacheTransferTimePerToken * nPromptTokens
+			return int(common.RandomNorm(float64(kvCacheTransT), float64(s.config.KVCacheTransferTimeStdDev)))
+		}
+		// is disaggregated PD and *not* using number of prompt tokens
+		return int(common.RandomNorm(float64(s.config.KVCacheTransferLatency), float64(s.config.KVCacheTransferLatencyStdDev)))
 	}
-	return int(common.RandomNorm(mean, stddev))
+	if s.config.TimeToFirstToken == 0 && s.config.TimeToFirstTokenStdDev == 0 {
+		// is aggregated PD and ttft is calculated using number of prompt tokens
+		prefillTime := s.config.PrefillOverhead + nPromptTokens*s.config.PrefillTimePerToken
+		return int(common.RandomNorm(float64(prefillTime), float64(s.config.PrefillTimeStdDev)))
+	}
+	// is aggregated PD and *not* using number of prompt tokens
+	return int(common.RandomNorm(float64(s.config.TimeToFirstToken), float64(s.config.TimeToFirstTokenStdDev)))
 }
 
 // returns inter token latency
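To make the numbers concrete: with time-to-first-token left at 0, prefill-overhead=100 and prefill-time-per-token=200, a 128-token prompt has a mean TTFT of 100 + 128 * 200 = 25700 ms, with prefill-time-std-dev only adding noise around that mean. With remote prefill, kv-cache-transfer-latency=0 and kv-cache-transfer-time-per-token=100, the same prompt gives 100 * 128 = 12800 ms, which is exactly the value asserted in the new tests below.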
diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go
index 9e4c882b..73699fbb 100644
--- a/pkg/llm-d-inference-sim/simulator_test.go
+++ b/pkg/llm-d-inference-sim/simulator_test.go
@@ -807,7 +807,7 @@ var _ = Describe("Simulator", func() {
 			simulator.config.TimeToFirstTokenStdDev = timeToFirstTokenStdDev
 			simulator.config.KVCacheTransferLatency = kvCacheLatency
 			simulator.config.KVCacheTransferLatencyStdDev = kvCacheLatencyStdDev
-			timeToFirst := simulator.getTimeToFirstToken(doREmotePrefill)
+			timeToFirst := simulator.getTimeToFirstToken(1, doREmotePrefill)
 			if doREmotePrefill {
 				Expect(timeToFirst).To(BeNumerically(">=", int(float32(kvCacheLatency)*0.3)))
 				Expect(timeToFirst).To(BeNumerically("<=", int(float32(kvCacheLatency)*1.7)))
@@ -828,5 +828,104 @@ var _ = Describe("Simulator", func() {
 			Entry(nil, 10000, 0, 1000, 0, true),
 			Entry(nil, 10000, 0, 1000, 0, false),
 		)
+
+		It("when time-to-first-token is not 0, ignore prefill-overhead", func() {
+			timeToFirstToken := 1000
+			simulator.config.TimeToFirstToken = timeToFirstToken
+			simulator.config.TimeToFirstTokenStdDev = 0
+
+			simulator.config.PrefillOverhead = 100
+			simulator.config.PrefillTimePerToken = 200
+			simulator.config.PrefillTimeStdDev = 80
+
+			ttft := simulator.getTimeToFirstToken(128, false)
+
+			Expect(ttft).To(BeNumerically("==", timeToFirstToken))
+		})
+
+		It("when time-to-first-token is 0, and prefill-overhead is not 0, use prefill-overhead", func() {
+			simulator.config.TimeToFirstToken = 0
+			simulator.config.TimeToFirstTokenStdDev = 0
+
+			simulator.config.PrefillOverhead = 100
+			simulator.config.PrefillTimePerToken = 200
+			simulator.config.PrefillTimeStdDev = 80
+
+			ttft := simulator.getTimeToFirstToken(128, false)
+			Expect(ttft).NotTo(BeNumerically("==", 0))
+		})
+
+		DescribeTable("time to first token scales with the number of prompt tokens",
+			func(prefillOverhead int, prefillTimePerToken int, stdDev int, nTokens int) {
+				simulator.config.TimeToFirstToken = 0
+				simulator.config.PrefillOverhead = prefillOverhead
+				simulator.config.PrefillTimePerToken = prefillTimePerToken
+				simulator.config.PrefillTimeStdDev = stdDev
+
+				ttft := simulator.getTimeToFirstToken(nTokens, false)
+
+				expectedTTFT := prefillOverhead + prefillTimePerToken*nTokens
+				Expect(ttft).To(BeNumerically(">=", int(float64(expectedTTFT)*0.3)))
+				Expect(ttft).To(BeNumerically("<=", int(float64(expectedTTFT)*1.7)))
+			},
+			func(prefillOverhead int, prefillTimePerToken, stdDev int, nTokens int) string {
+				return fmt.Sprintf("prefillOverhead: %d, prefillTimePerToken: %d, stdDev: %d, nTokens: %d",
+					prefillOverhead, prefillTimePerToken, stdDev, nTokens)
+			},
+			Entry("single token", 100, 50, 70, 1),
+			Entry("stddev is 0", 100, 50, 0, 1),
+			Entry("medium overhead, 512 tokens", 200, 1000, 150, 512),
+			Entry("large overhead, 1024 tokens", 2000, 3000, 1800, 1024),
+			Entry("very long prompt", 150, 200, 100, 20000),
+		)
+
+		It("when kv-cache-transfer-latency is not 0, ignore kv-cache-transfer-time-per-token", func() {
+			simulator.config.KVCacheTransferLatency = 200
+			simulator.config.KVCacheTransferLatencyStdDev = 0
+
+			simulator.config.KVCacheTransferTimePerToken = 100
+			simulator.config.KVCacheTransferTimeStdDev = 0
+
+			ttft := simulator.getTimeToFirstToken(128, true)
+			Expect(ttft).To(BeNumerically("==", 200))
+		})
+
+		It("when kv-cache-transfer-latency is 0, and kv-cache-transfer-time-per-token is not 0, use kv-cache-transfer-time-per-token", func() {
+			simulator.config.KVCacheTransferLatency = 0
+			simulator.config.KVCacheTransferLatencyStdDev = 0
+
+			simulator.config.KVCacheTransferTimePerToken = 100
+			simulator.config.KVCacheTransferTimeStdDev = 0
+
+			ttft := simulator.getTimeToFirstToken(128, true)
+			Expect(ttft).To(BeNumerically("==", 12800))
+		})
+
+		DescribeTable("kv-cache transfer time scales with the number of prompt tokens",
+			func(kvCacheTransTPT int, stddev int, nTokens int) {
+				simulator.config.TimeToFirstToken = 0
+				simulator.config.PrefillOverhead = 1
+				simulator.config.KVCacheTransferTimePerToken = kvCacheTransTPT
+				simulator.config.KVCacheTransferTimeStdDev = stddev
+
+				ttft := simulator.getTimeToFirstToken(nTokens, true)
+
+				expectedTTFT := kvCacheTransTPT * nTokens
+				Expect(ttft).To(BeNumerically(">=", int(float64(expectedTTFT)*0.3)))
+				Expect(ttft).To(BeNumerically("<=", int(float64(expectedTTFT)*1.7)))
+			},
+			func(kvCacheTransferTimePerToken int, stddev int, nTokens int) string {
+				return fmt.Sprintf("kvCacheTransferTimePerToken: %d stddev: %d nTokens: %d",
+					kvCacheTransferTimePerToken, stddev, nTokens)
+			},
+			Entry("single token", 100, 70, 1),
+			Entry("stddev is 0", 100, 0, 1),
+			Entry("medium overhead, 512 tokens", 200, 150, 512),
+			Entry("large overhead, 1024 tokens", 2000, 1800, 1024),
+			Entry("very long prompt", 150, 100, 20000),
+		)
+
 	})
 })
diff --git a/pkg/llm-d-inference-sim/streaming.go b/pkg/llm-d-inference-sim/streaming.go
index 969f29af..d234114a 100644
--- a/pkg/llm-d-inference-sim/streaming.go
+++ b/pkg/llm-d-inference-sim/streaming.go
@@ -39,7 +39,7 @@ type streamingContext struct {
 // as defined by isChatCompletion
 // response content is wrapped according SSE format
 // First token is send after timeToFirstToken milliseconds, every other token is sent after interTokenLatency milliseconds
-func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, responseTokens []string, toolCalls []openaiserverapi.ToolCall,
+func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, nPromptTokens int, responseTokens []string, toolCalls []openaiserverapi.ToolCall,
 	finishReason string, usageData *openaiserverapi.Usage) {
 	context.ctx.SetContentType("text/event-stream")
 	context.ctx.SetStatusCode(fasthttp.StatusOK)
@@ -67,11 +67,11 @@ func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, respons
 		if len(toolCalls) > 0 {
 			s.logger.Info("Going to send tools calls")
 			for _, tc := range toolCalls {
-				s.sendTokenChunks(context, w, tc.Function.TokenizedArguments, &tc, finishReason)
+				s.sendTokenChunks(context, w, nPromptTokens, tc.Function.TokenizedArguments, &tc, finishReason)
 			}
 		} else {
 			s.logger.Info("Going to send text", "number of tokens", len(responseTokens))
-			s.sendTokenChunks(context, w, responseTokens, nil, finishReason)
+			s.sendTokenChunks(context, w, nPromptTokens, responseTokens, nil, finishReason)
 		}
 	}
@@ -94,11 +94,11 @@ func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, respons
 }
 
 // sendTokenChunks creates and sends response chunks
-func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writer, tokens []string, tc *openaiserverapi.ToolCall, finishReason string) {
+func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writer, nPromptTokens int, genTokens []string, tc *openaiserverapi.ToolCall, finishReason string) {
 	// time to first token delay
-	time.Sleep(time.Duration(s.getTimeToFirstToken(context.doRemotePrefill)) * time.Millisecond)
+	time.Sleep(time.Duration(s.getTimeToFirstToken(nPromptTokens, context.doRemotePrefill)) * time.Millisecond)
 
-	for i, token := range tokens {
+	for i, token := range genTokens {
 		if i != 0 {
 			time.Sleep(time.Duration(s.getInterTokenLatency()) * time.Millisecond)
 		}
@@ -119,7 +119,7 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ
 		var chunk openaiserverapi.CompletionRespChunk
 		var finishReasonToSend *string
-		if i == len(tokens)-1 && (finishReason == common.LengthFinishReason || finishReason == common.ToolsFinishReason) {
+		if i == len(genTokens)-1 && (finishReason == common.LengthFinishReason || finishReason == common.ToolsFinishReason) {
 			finishReasonToSend = &finishReason
 		}
 		if context.isChatCompletion {
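With nPromptTokens threaded through to sendTokenChunks, streaming follows the same timing model as the blocking path above: the first SSE chunk is delayed by getTimeToFirstToken(nPromptTokens, doRemotePrefill), and every subsequent generated token adds one inter-token latency, so a stream of n tokens takes roughly TTFT + (n - 1) * inter-token latency.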
diff --git a/pkg/openai-server-api/request.go b/pkg/openai-server-api/request.go
index d368a211..b23104f8 100644
--- a/pkg/openai-server-api/request.go
+++ b/pkg/openai-server-api/request.go
@@ -53,9 +53,13 @@ type CompletionRequest interface {
 	GetToolChoice() string
 	// GetMaxCompletionTokens returns the maximum completion tokens requested
 	GetMaxCompletionTokens() *int64
-	// IsDoRemoteDecode() returns true if do_remote_decode field is true in the request, this means that this is prefill request
+	// IsDoRemoteDecode() returns true if the do_remote_decode field is true in the request;
+	// when the field is true, the decode phase should be done on a remote pod,
+	// whereas the prefill phase is done on the local pod, so this is a prefill request
 	IsDoRemoteDecode() bool
-	// IsDoRemotePrefill() returns true if do_remote_prefill field is true in the request, this means that this is decode request
+	// IsDoRemotePrefill() returns true if the do_remote_prefill field is true in the request;
+	// when the field is true, the prefill phase should be done on a remote pod,
+	// whereas the decode phase is done on the local pod, so this is a decode request
 	IsDoRemotePrefill() bool
 }
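To show how these two flags are exercised end to end, here is a hedged client sketch (not part of this change): the do_remote_prefill and do_remote_decode field names come from the comments above, while the listen address, port, endpoint path, and the rest of the payload are assumptions. A request with do_remote_prefill set to true is handled as a decode request, so its TTFT comes from the KV-cache transfer settings rather than the prefill settings.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Hypothetical payload: only do_remote_prefill/do_remote_decode are taken from the
	// request.go comments; model name, prompt, address, and path are assumptions.
	body, _ := json.Marshal(map[string]any{
		"model":             "my-model",
		"prompt":            "hello world",
		"do_remote_prefill": true,  // prefill already ran on a remote vLLM
		"do_remote_decode":  false, // decode runs locally, so this is a decode request
	})
	resp, err := http.Post("http://localhost:8000/v1/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.Status)
}
```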