diff --git a/pkg/llm-d-inference-sim/failures_test.go b/pkg/llm-d-inference-sim/failures_test.go index 24ebf0a8..36d324a0 100644 --- a/pkg/llm-d-inference-sim/failures_test.go +++ b/pkg/llm-d-inference-sim/failures_test.go @@ -26,7 +26,6 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/openai/openai-go" - "github.com/openai/openai-go/option" "github.com/llm-d/llm-d-inference-sim/pkg/common" openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" @@ -135,18 +134,8 @@ var _ = Describe("Failures", func() { }) It("should always return an error response for chat completions", func() { - openaiClient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client), - ) - - _, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ - Model: model, - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(userMessage), - }, - }) - + openaiClient, params := getOpenAIClentAndChatParams(client, model, userMessage, false) + _, err := openaiClient.Chat.Completions.New(ctx, params) Expect(err).To(HaveOccurred()) var openaiError *openai.Error @@ -158,18 +147,8 @@ var _ = Describe("Failures", func() { }) It("should always return an error response for text completions", func() { - openaiClient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client), - ) - - _, err := openaiClient.Completions.New(ctx, openai.CompletionNewParams{ - Model: openai.CompletionNewParamsModel(model), - Prompt: openai.CompletionNewParamsPromptUnion{ - OfString: openai.String(userMessage), - }, - }) - + openaiClient, params := getOpenAIClentAndCompletionParams(client, model, userMessage, false) + _, err := openaiClient.Completions.New(ctx, params) Expect(err).To(HaveOccurred()) var openaiError *openai.Error @@ -194,18 +173,8 @@ var _ = Describe("Failures", func() { }) It("should return only rate limit errors", func() { - openaiClient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client), - ) - - _, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ - Model: model, - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(userMessage), - }, - }) - + openaiClient, params := getOpenAIClentAndChatParams(client, model, userMessage, false) + _, err := openaiClient.Chat.Completions.New(ctx, params) Expect(err).To(HaveOccurred()) var openaiError *openai.Error @@ -230,20 +199,11 @@ var _ = Describe("Failures", func() { }) It("should return only specified error types", func() { - openaiClient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client), - ) + openaiClient, params := getOpenAIClentAndChatParams(client, model, userMessage, false) // Make multiple requests to verify we get the expected error types for i := 0; i < 10; i++ { - _, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ - Model: model, - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(userMessage), - }, - }) - + _, err := openaiClient.Chat.Completions.New(ctx, params) Expect(err).To(HaveOccurred()) var openaiError *openai.Error @@ -270,18 +230,8 @@ var _ = Describe("Failures", func() { }) It("should never return errors and behave like random mode", func() { - openaiClient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client), - ) - - resp, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ - Model: model, - Messages:
[]openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(userMessage), - }, - }) - + openaiClient, params := getOpenAIClentAndChatParams(client, model, userMessage, false) + resp, err := openaiClient.Chat.Completions.New(ctx, params) Expect(err).ToNot(HaveOccurred()) Expect(resp.Choices).To(HaveLen(1)) Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) @@ -300,18 +250,8 @@ var _ = Describe("Failures", func() { }, nil) Expect(err).ToNot(HaveOccurred()) - openaiClient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client), - ) - - _, err = openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ - Model: model, - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(userMessage), - }, - }) - + openaiClient, params := getOpenAIClentAndChatParams(client, model, userMessage, false) + _, err = openaiClient.Chat.Completions.New(ctx, params) Expect(err).To(HaveOccurred()) var openaiError *openai.Error diff --git a/pkg/llm-d-inference-sim/helpers.go b/pkg/llm-d-inference-sim/helpers.go new file mode 100644 index 00000000..60089da7 --- /dev/null +++ b/pkg/llm-d-inference-sim/helpers.go @@ -0,0 +1,94 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package vllmsim implements the vLLM simulator. +package llmdinferencesim + +import ( + "encoding/json" + "fmt" +) + +// isValidModel checks if the given model is the base model or one of "loaded" LoRAs +func (s *VllmSimulator) isValidModel(model string) bool { + for _, name := range s.config.ServedModelNames { + if model == name { + return true + } + } + for _, lora := range s.getLoras() { + if model == lora { + return true + } + } + + return false +} + +// isLora returns true if the given model name is one of loaded LoRAs +func (s *VllmSimulator) isLora(model string) bool { + for _, lora := range s.getLoras() { + if model == lora { + return true + } + } + + return false +} + +// getDisplayedModelName returns the model name that must appear in API +// responses. LoRA adapters keep their explicit name, while all base-model +// requests are surfaced as the first alias from --served-model-name. 
+func (s *VllmSimulator) getDisplayedModelName(reqModel string) string { + if s.isLora(reqModel) { + return reqModel + } + return s.config.ServedModelNames[0] +} + +func (s *VllmSimulator) showConfig(dp bool) error { + cfgJSON, err := json.Marshal(s.config) + if err != nil { + return fmt.Errorf("failed to marshal configuration to JSON: %w", err) + } + + var m map[string]interface{} + err = json.Unmarshal(cfgJSON, &m) + if err != nil { + return fmt.Errorf("failed to unmarshal JSON to map: %w", err) + } + if dp { + // remove the port + delete(m, "port") + } + // clean LoraModulesString field + m["lora-modules"] = m["LoraModules"] + delete(m, "LoraModules") + delete(m, "LoraModulesString") + + // clean fake-metrics field + if field, ok := m["fake-metrics"].(map[string]interface{}); ok { + delete(field, "LorasString") + } + + // show in JSON + cfgJSON, err = json.MarshalIndent(m, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal configuration to JSON: %w", err) + } + s.logger.Info("Configuration:", "", string(cfgJSON)) + return nil +} diff --git a/pkg/llm-d-inference-sim/latencies.go b/pkg/llm-d-inference-sim/latencies.go new file mode 100644 index 00000000..765e362d --- /dev/null +++ b/pkg/llm-d-inference-sim/latencies.go @@ -0,0 +1,65 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package vllmsim implements the vLLM simulator. 
+package llmdinferencesim + +import "github.com/llm-d/llm-d-inference-sim/pkg/common" + +func (s *VllmSimulator) getCurrLoadFactor() float64 { + if s.config.MaxNumSeqs <= 1 { + return 1.0 + } + return 1 + (s.config.TimeFactorUnderLoad-1)*float64(s.nRunningReqs-1)/float64(s.config.MaxNumSeqs-1) +} + +func (s *VllmSimulator) getTimeToFirstToken() int { + return int(float64(s.config.TimeToFirstToken) * s.getCurrLoadFactor()) +} + +func (s *VllmSimulator) getPrefillOverhead() int { + return int(float64(s.config.PrefillOverhead) * s.getCurrLoadFactor()) +} + +func (s *VllmSimulator) getPrefillTimePerToken() int { + return int(float64(s.config.PrefillTimePerToken) * s.getCurrLoadFactor()) +} + +// returns time to first token based on the current request's doRemotePrefill +func (s *VllmSimulator) getWaitTimeToFirstToken(nPromptTokens int, nCachedPromptTokens int, doRemotePrefill bool) int { + if doRemotePrefill { + if s.config.KVCacheTransferLatency == 0 && s.config.KVCacheTransferLatencyStdDev == 0 { + // is disaggregated PD and ttft is calculated using number of prompt tokens + kvCacheTransT := s.config.KVCacheTransferTimePerToken * nPromptTokens + return common.RandomNorm(kvCacheTransT, s.config.KVCacheTransferTimeStdDev) + } + // is disaggregated PD and *not* using number of prompt tokens + return common.RandomNorm(s.config.KVCacheTransferLatency, s.config.KVCacheTransferLatencyStdDev) + } + if s.config.TimeToFirstToken == 0 && s.config.TimeToFirstTokenStdDev == 0 { + // is aggregated PD and ttft is calculated using number of prompt tokens that are not in kv cache + prefillTime := s.getPrefillOverhead() + (nPromptTokens-nCachedPromptTokens)*s.getPrefillTimePerToken() + return common.RandomNorm(prefillTime, s.config.PrefillTimeStdDev) + } + // is aggregated PD and *not* using number of prompt tokens + return common.RandomNorm(s.getTimeToFirstToken(), s.config.TimeToFirstTokenStdDev) +} + +// returns inter token latency +func (s *VllmSimulator) getInterTokenLatency() int { + latency := int(float64(s.config.InterTokenLatency) * s.getCurrLoadFactor()) + return common.RandomNorm(latency, s.config.InterTokenLatencyStdDev) +} diff --git a/pkg/llm-d-inference-sim/latencies_test.go b/pkg/llm-d-inference-sim/latencies_test.go new file mode 100644 index 00000000..5eddfecb --- /dev/null +++ b/pkg/llm-d-inference-sim/latencies_test.go @@ -0,0 +1,346 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llmdinferencesim + +import ( + "fmt" + "time" + + "github.com/llm-d/llm-d-inference-sim/pkg/common" + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + "k8s.io/klog/v2" +) + +var _ = Describe("Check random latencies", Ordered, func() { + var simulator *VllmSimulator + + BeforeAll(func() { + var err error + simulator, err = New(klog.Background()) + Expect(err).NotTo(HaveOccurred()) + + simulator.config = &common.Configuration{ + TimeToFirstToken: 2048, + TimeToFirstTokenStdDev: 2048, + KVCacheTransferLatency: 2048, + KVCacheTransferLatencyStdDev: 2048, + } + + common.InitRandom(time.Now().UnixNano()) + }) + + DescribeTable("should calculate inter token latency correctly", + func(interTokenLatency int, stddev int) { + simulator.config.InterTokenLatency = interTokenLatency + simulator.config.InterTokenLatencyStdDev = stddev + interToken := simulator.getInterTokenLatency() + Expect(interToken).To(BeNumerically(">=", int(float32(interTokenLatency)*0.3))) + Expect(interToken).To(BeNumerically("<=", int(float32(interTokenLatency)*1.7))) + }, + func(interTokenLatency int, stddev int) string { + return fmt.Sprintf("interTokenLatency: %d stddev: %d", interTokenLatency, stddev) + }, + Entry(nil, 1000, 300), + Entry(nil, 1000, 800), // invalid std dev, used for testing purposes + Entry(nil, 1000, 900), // invalid std dev, used for testing purposes + Entry(nil, 1000, 0), + ) + + DescribeTable("should calculate total inter token latency correctly", + func(interTokenLatency int, stddev int, numberOfTokens int) { + simulator.config.InterTokenLatency = interTokenLatency + simulator.config.InterTokenLatencyStdDev = stddev + simulator.config.MaxNumSeqs = 1 + simulator.config.TimeFactorUnderLoad = 1.0 + + latency := 0 + for range numberOfTokens - 1 { + latency += simulator.getInterTokenLatency() + } + + Expect(latency).To(BeNumerically(">=", int(float32(interTokenLatency)*0.3*float32(numberOfTokens)))) + Expect(latency).To(BeNumerically("<=", int(float32(interTokenLatency)*1.7*float32(numberOfTokens)))) + }, + func(interTokenLatency int, stddev int, numberOfTokens int) string { + return fmt.Sprintf("interTokenLatency: %d stddev: %d, numberOfTokens: %d", interTokenLatency, + stddev, numberOfTokens) + }, + Entry(nil, 1000, 30, 100), + Entry(nil, 1000, 800, 20), // invalid std dev, used for testing purposes + Entry(nil, 1000, 900, 5), // invalid std dev, used for testing purposes + Entry(nil, 1000, 0, 50), + ) + + DescribeTable("should calculate time to first token correctly", + func(timeToFirstToken int, timeToFirstTokenStdDev int, + kvCacheLatency int, kvCacheLatencyStdDev int, doREmotePrefill bool) { + simulator.config.TimeToFirstToken = timeToFirstToken + simulator.config.TimeToFirstTokenStdDev = timeToFirstTokenStdDev + simulator.config.KVCacheTransferLatency = kvCacheLatency + simulator.config.KVCacheTransferLatencyStdDev = kvCacheLatencyStdDev + timeToFirst := simulator.getWaitTimeToFirstToken(1, 0, doREmotePrefill) + if doREmotePrefill { + Expect(timeToFirst).To(BeNumerically(">=", int(float32(kvCacheLatency)*0.3))) + Expect(timeToFirst).To(BeNumerically("<=", int(float32(kvCacheLatency)*1.7))) + } else { + Expect(timeToFirst).To(BeNumerically(">=", int(float32(timeToFirstToken)*0.3))) + Expect(timeToFirst).To(BeNumerically("<=", int(float32(timeToFirstToken)*1.7))) + } + }, + func(timeToFirstToken int, timeToFirstTokenStdDev int, + kvCacheLatency int, kvCacheLatencyStdDev int, doREmotePrefill bool) string { + return fmt.Sprintf("timeToFirstToken: %d stddev: %d kvCacheLatency: %d stddev: %d doREmotePrefill: %t", + timeToFirstToken, timeToFirstTokenStdDev, kvCacheLatency, kvCacheLatencyStdDev, doREmotePrefill) + }, + Entry(nil, 
10000, 300, 1000, 200, true), + Entry(nil, 10000, 300, 1000, 200, false), + Entry(nil, 10000, 9000, 1000, 800, true), // invalid std dev, used for testing purposes + Entry(nil, 10000, 8000, 1000, 900, false), // invalid std dev, used for testing purposes + Entry(nil, 10000, 0, 1000, 0, true), + Entry(nil, 10000, 0, 1000, 0, false), + ) + + It("when time-to-first-token is not 0, ignore prefill-overhead and prefill-time-per-token", func() { + timeToFirstToken := 1000 + simulator.config.TimeToFirstToken = timeToFirstToken + simulator.config.TimeToFirstTokenStdDev = 0 + + simulator.config.PrefillOverhead = 100 + simulator.config.PrefillTimePerToken = 200 + simulator.config.PrefillTimeStdDev = 80 + + ttft := simulator.getWaitTimeToFirstToken(128, 0, false) + + Expect(ttft).To(BeNumerically("==", timeToFirstToken)) + }) + + It("when time-to-first-token is 0 and prefill-overhead is not 0, use prefill-overhead and prefill-time-per-token", func() { + simulator.config.TimeToFirstToken = 0 + simulator.config.TimeToFirstTokenStdDev = 0 + + simulator.config.PrefillOverhead = 100 + simulator.config.PrefillTimePerToken = 200 + simulator.config.PrefillTimeStdDev = 80 + + ttft := simulator.getWaitTimeToFirstToken(128, 0, false) + Expect(ttft).NotTo(BeNumerically("==", 0)) + }) + + DescribeTable("time to first token is calculated against the number of prompt tokens, with std", + func(prefillOverhead int, prefillTimePerToken int, stdDev int, nTokens int, nCachedTokens int) { + simulator.config.TimeToFirstToken = 0 + simulator.config.PrefillOverhead = prefillOverhead + simulator.config.PrefillTimePerToken = prefillTimePerToken + simulator.config.PrefillTimeStdDev = stdDev + + ttft := simulator.getWaitTimeToFirstToken(nTokens, nCachedTokens, false) + + expectedTTFT := prefillOverhead + prefillTimePerToken*(nTokens-nCachedTokens) + Expect(ttft).To(BeNumerically(">=", int(float64(expectedTTFT)*0.3))) + Expect(ttft).To(BeNumerically("<=", int(float64(expectedTTFT)*1.7))) + }, + func(prefillOverhead int, prefillTimePerToken, stdDev int, nTokens int, nCachedTokens int) string { + return fmt.Sprintf("prefillOverhead: %d, prefillTimePerToken: %d, stdDev: %d, nTokens: %d nCachedTokens: %d", + prefillOverhead, prefillTimePerToken, stdDev, nTokens, nCachedTokens) + }, + Entry("single token", 100, 50, 10, 1, 0), + Entry("single token big std", 100, 50, 70, 1, 0), + Entry("stddev is 0", 100, 50, 0, 1, 0), + Entry("medium overhead, 512 tokens", 200, 1000, 150, 512, 0), + Entry("large overhead, 1024 tokens", 2000, 3000, 800, 1024, 0), + Entry("very long prompt", 150, 200, 70, 20000, 0), + Entry("medium overhead, 512 tokens, 256 cached", 200, 1000, 150, 512, 256), + Entry("large overhead, 1024 tokens, 1008 cached", 2000, 3000, 800, 1024, 1008), + Entry("very long prompt, 1024 cached", 150, 200, 70, 20000, 1024), + ) + + DescribeTable("time to first token is calculated against the number of prompt tokens", + func(prefillOverhead int, prefillTimePerToken int, nTokens int, nCachedTokens int) { + simulator.config.TimeToFirstToken = 0 + simulator.config.PrefillOverhead = prefillOverhead + simulator.config.PrefillTimePerToken = prefillTimePerToken + simulator.config.PrefillTimeStdDev = 0 + + ttft := simulator.getWaitTimeToFirstToken(nTokens, nCachedTokens, false) + expectedTTFT := prefillOverhead + prefillTimePerToken*(nTokens-nCachedTokens) + Expect(ttft).To(Equal(expectedTTFT)) + }, + func(prefillOverhead int, prefillTimePerToken, nTokens int, nCachedTokens int) string { + return fmt.Sprintf("prefillOverhead: %d, prefillTimePerToken: %d, nTokens: %d nCachedTokens: %d", + prefillOverhead, prefillTimePerToken, nTokens, nCachedTokens) + }, + Entry("single token", 100, 50, 1, 0), + Entry("medium
overhead, 512 tokens", 200, 1000, 512, 0), + Entry("large overhead, 1024 tokens", 2000, 3000, 1024, 0), + Entry("very long prompt", 150, 200, 20000, 0), + Entry("medium overhead, 512 tokens, 256 cached", 200, 1000, 512, 256), + Entry("large overhead, 1024 tokens, 128 cached", 2000, 3000, 1024, 128), + Entry("very long prompt, 1024 cached", 150, 200, 20000, 1024), + ) + + It("when not 0, ignore ", func() { + simulator.config.KVCacheTransferLatency = 200 + simulator.config.KVCacheTransferLatencyStdDev = 0 + + simulator.config.KVCacheTransferTimePerToken = 100 + simulator.config.KVCacheTransferTimeStdDev = 0 + + ttft := simulator.getWaitTimeToFirstToken(128, 0, true) + Expect(ttft).To(BeNumerically("==", 200)) + }) + + It("when is 0, and is not 0, use ", func() { + simulator.config.KVCacheTransferLatency = 0 + simulator.config.KVCacheTransferLatencyStdDev = 0 + + simulator.config.KVCacheTransferTimePerToken = 100 + simulator.config.KVCacheTransferTimeStdDev = 0 + + ttft := simulator.getWaitTimeToFirstToken(128, 0, true) + Expect(ttft).To(BeNumerically("==", 12800)) + }) + + DescribeTable("kv cache transfer time against number of prompt tokens", + func(kvCacheTransTPT int, stddev int, nTokens int) { + simulator.config.TimeToFirstToken = 0 + simulator.config.PrefillOverhead = 1 + simulator.config.KVCacheTransferTimePerToken = kvCacheTransTPT + simulator.config.KVCacheTransferTimeStdDev = stddev + + ttft := simulator.getWaitTimeToFirstToken(nTokens, 0, true) + + expectedTTFT := kvCacheTransTPT * nTokens + Expect(ttft).To(BeNumerically(">=", int(float64(expectedTTFT)*0.3))) + Expect(ttft).To(BeNumerically("<=", int(float64(expectedTTFT)*1.7))) + + }, + func(kvCacheTransferTimePerToken int, stddev int, nTokens int) string { + return fmt.Sprintf("kvCacheTransferTimePerToken: %d stddev: %d nTokens: %d", + kvCacheTransferTimePerToken, stddev, nTokens) + }, + Entry("single token", 100, 70, 1), + Entry("stddev is 0", 100, 0, 1), + Entry("medium overhead, 512 tokens", 200, 150, 512), + Entry("large overhead, 1024 tokens", 2000, 1800, 1024), + Entry("very long prompt", 150, 100, 20000), + ) + + It("when time-factor-under-load is 1, the time to first token should be equal to time-to-first-token", func() { + simulator.config.TimeToFirstToken = 42 + simulator.config.TimeToFirstTokenStdDev = 0 + simulator.config.TimeFactorUnderLoad = 1.0 + + simulator.runReqChan <- 100 + + ttft := simulator.getWaitTimeToFirstToken(128, 0, false) + Expect(ttft).To(Equal(42)) + }) + + It("when time-factor-under-load is > 1, but max-num-seqs is 1, the factor will not take effect", func() { + simulator.config.TimeToFirstToken = 42 + simulator.config.TimeToFirstTokenStdDev = 0 + simulator.config.TimeFactorUnderLoad = 100.0 + simulator.config.MaxNumSeqs = 1 + + for len(simulator.runReqChan) > 0 { + <-simulator.runReqChan + } + + simulator.runReqChan <- 1 + + ttft := simulator.getWaitTimeToFirstToken(128, 0, false) + Expect(ttft).To(Equal(42)) + }) + + DescribeTable("when time-factor-under-load is > 1, and the sim is fully loaded, the time to first token should be time-factor-under-load * time-to-first-token", + func(timeFactorUnderLoad float64, maxNumOfReq int) { + simulator.config.TimeToFirstToken = 42 + simulator.config.TimeToFirstTokenStdDev = 0 + simulator.config.TimeFactorUnderLoad = timeFactorUnderLoad + simulator.config.MaxNumSeqs = maxNumOfReq + simulator.nRunningReqs = int64(maxNumOfReq) + + ttft := simulator.getWaitTimeToFirstToken(128, 0, false) + Expect(ttft).To(Equal(int(float64(42) * timeFactorUnderLoad))) + + }, + 
func(timeFactorUnderLoad float64, maxNumOfReq int64) string { + return fmt.Sprintf("timeFactorUnderLoad: %f maxNumOfReq: %d", + timeFactorUnderLoad, maxNumOfReq) + }, + + Entry("factor: 1.5", 1.5, 70), + Entry("factor: 2.0", 2.0, 2), + Entry("factor: 100.0", 100.0, 150), + Entry("factor: 20000.0", 20000.0, 310), + ) + + DescribeTable("when time-factor-under-load is > 1, and the sim is partially loaded, the time to first token should be linear interpolation between time-to-first-token and time-factor-under-load * time-to-first-token", + func(timeFactorUnderLoad float64, maxNumOfReq int, nCurrNumOfReq int) { + simulator.config.TimeToFirstToken = 42 + simulator.config.TimeToFirstTokenStdDev = 0 + simulator.config.TimeFactorUnderLoad = timeFactorUnderLoad + simulator.config.MaxNumSeqs = maxNumOfReq + simulator.nRunningReqs = int64(nCurrNumOfReq) + + ttft := simulator.getWaitTimeToFirstToken(128, 0, false) + max := timeFactorUnderLoad * float64(42) + Expect(ttft).To(BeNumerically(">=", 42)) + Expect(ttft).To(BeNumerically("<=", max)) + + }, + func(timeFactorUnderLoad float64, maxNumOfReq int, nCurrNumOfReq int) string { + return fmt.Sprintf("timeFactorUnderLoad: %f maxNumOfReq: %d nCurrNumOfReq: %d", + timeFactorUnderLoad, maxNumOfReq, nCurrNumOfReq) + }, + + Entry("factor: 1.5", 1.5, 70, 35), + Entry("factor: 2.0", 2.0, 2, 1), + Entry("factor: 100.0", 100.0, 150, 75), + Entry("factor: 20000.0", 20000.0, 310, 155), + ) + + It("when TimeFactorUnderLoad is 1.0, calcLoadFactor should give 1", func() { + simulator.config.TimeFactorUnderLoad = 1.0 + simulator.config.MaxNumSeqs = 11 + simulator.nRunningReqs = 3 + + factor := simulator.getCurrLoadFactor() + Expect(factor).To(BeNumerically("==", 1.0)) + }) + + It("when TimeFactorUnderLoad is > 1.0, and sim is fully loaded, calcLoadFactor should give TimeFactorUnderLoad", func() { + simulator.config.TimeFactorUnderLoad = 2.0 + simulator.config.MaxNumSeqs = 11 + simulator.nRunningReqs = 11 + + factor := simulator.getCurrLoadFactor() + Expect(factor).To(BeNumerically("==", simulator.config.TimeFactorUnderLoad)) + + }) + + It("when TimeFactorUnderLoad is > 1.0, and sim is partially loaded, calcLoadFactor should give a value between 1 and TimeFactorUnderLoad", func() { + simulator.config.TimeFactorUnderLoad = 2.0 + simulator.config.MaxNumSeqs = 11 + simulator.nRunningReqs = 6 + + factor := simulator.getCurrLoadFactor() + Expect(factor).To(BeNumerically(">", 1.0)) + Expect(factor).To(BeNumerically("<", simulator.config.TimeFactorUnderLoad)) + }) +}) diff --git a/pkg/llm-d-inference-sim/lora_test.go b/pkg/llm-d-inference-sim/lora_test.go index 7ec37d0d..624684f2 100644 --- a/pkg/llm-d-inference-sim/lora_test.go +++ b/pkg/llm-d-inference-sim/lora_test.go @@ -40,17 +40,8 @@ var _ = Describe("LoRAs", func() { "{\"name\":\"lora4\",\"path\":\"/path/to/lora4\"}"}, nil) Expect(err).NotTo(HaveOccurred()) - openaiclient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client)) - // Request to lora3 - params := openai.ChatCompletionNewParams{ - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(userMessage), - }, - Model: "lora3", - } + openaiclient, params := getOpenAIClentAndChatParams(client, "lora3", userMessage, false) resp, err := openaiclient.Chat.Completions.New(ctx, params) Expect(err).ToNot(HaveOccurred()) diff --git a/pkg/llm-d-inference-sim/metrics_test.go b/pkg/llm-d-inference-sim/metrics_test.go index 2fa385ba..744f54e1 100644 --- a/pkg/llm-d-inference-sim/metrics_test.go +++ 
b/pkg/llm-d-inference-sim/metrics_test.go @@ -73,16 +73,7 @@ var _ = Describe("Simulator metrics", Ordered, func() { client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil) Expect(err).NotTo(HaveOccurred()) - openaiclient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client)) - - params := openai.ChatCompletionNewParams{ - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage(userMessage), - }, - Model: modelName, - } + openaiclient, params := getOpenAIClentAndChatParams(client, modelName, userMessage, false) var wg sync.WaitGroup wg.Add(1) diff --git a/pkg/llm-d-inference-sim/seed_test.go b/pkg/llm-d-inference-sim/seed_test.go index 505b4938..1f48deae 100644 --- a/pkg/llm-d-inference-sim/seed_test.go +++ b/pkg/llm-d-inference-sim/seed_test.go @@ -23,7 +23,6 @@ import ( . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/openai/openai-go" - "github.com/openai/openai-go/option" ) var _ = Describe("Simulator with seed", func() { @@ -36,17 +35,8 @@ var _ = Describe("Simulator with seed", func() { []string{"cmd", "--model", model, "--mode", common.ModeRandom, "--seed", "100"}, nil) Expect(err).NotTo(HaveOccurred()) - openaiclient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client)) - params := openai.CompletionNewParams{ - Prompt: openai.CompletionNewParamsPromptUnion{ - OfString: openai.String(userMessage), - }, - Model: openai.CompletionNewParamsModel(model), - MaxTokens: openai.Int(10), - } - + openaiclient, params := getOpenAIClentAndCompletionParams(client, model, userMessage, false) + params.MaxTokens = openai.Int(10) resp, err := openaiclient.Completions.New(ctx, params) Expect(err).NotTo(HaveOccurred()) Expect(resp.Choices).ShouldNot(BeEmpty()) @@ -77,16 +67,7 @@ var _ = Describe("Simulator with seed", func() { client, err := startServer(ctx, common.ModeRandom) Expect(err).NotTo(HaveOccurred()) - openaiclient := openai.NewClient( - option.WithBaseURL(baseURL), - option.WithHTTPClient(client)) - params := openai.CompletionNewParams{ - Prompt: openai.CompletionNewParamsPromptUnion{ - OfString: openai.String(userMessage), - }, - Model: openai.CompletionNewParamsModel(model), - } - + openaiclient, params := getOpenAIClentAndCompletionParams(client, model, userMessage, false) resp, err := openaiclient.Completions.New(ctx, params) Expect(err).NotTo(HaveOccurred()) Expect(resp.Choices).ShouldNot(BeEmpty()) diff --git a/pkg/llm-d-inference-sim/server.go b/pkg/llm-d-inference-sim/server.go new file mode 100644 index 00000000..5fb77d5e --- /dev/null +++ b/pkg/llm-d-inference-sim/server.go @@ -0,0 +1,318 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Package vllmsim implements the vLLM simulator. 
+package llmdinferencesim + +import ( + "context" + "encoding/json" + "fmt" + "net" + + "github.com/buaazp/fasthttprouter" + "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/fasthttpadaptor" + + "github.com/llm-d/llm-d-inference-sim/pkg/common" + openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" + vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api" +) + +func (s *VllmSimulator) newListener() (net.Listener, error) { + s.logger.Info("Server starting", "port", s.config.Port) + listener, err := net.Listen("tcp4", fmt.Sprintf(":%d", s.config.Port)) + if err != nil { + return nil, err + } + return listener, nil +} + +// startServer starts http server on port defined in command line +func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener) error { + r := fasthttprouter.New() + + // support completion APIs + r.POST("/v1/chat/completions", s.HandleChatCompletions) + r.POST("/v1/completions", s.HandleTextCompletions) + // supports /models API + r.GET("/v1/models", s.HandleModels) + // support load/unload of lora adapter + r.POST("/v1/load_lora_adapter", s.HandleLoadLora) + r.POST("/v1/unload_lora_adapter", s.HandleUnloadLora) + // supports /metrics prometheus API + r.GET("/metrics", fasthttpadaptor.NewFastHTTPHandler(promhttp.HandlerFor(s.registry, promhttp.HandlerOpts{}))) + // supports standard Kubernetes health and readiness checks + r.GET("/health", s.HandleHealth) + r.GET("/ready", s.HandleReady) + r.POST("/tokenize", s.HandleTokenize) + + server := fasthttp.Server{ + ErrorHandler: s.HandleError, + Handler: r.Handler, + Logger: s, + } + + // Start server in a goroutine + serverErr := make(chan error, 1) + go func() { + s.logger.Info("HTTP server starting") + serverErr <- server.Serve(listener) + }() + + // Wait for either context cancellation or server error + select { + case <-ctx.Done(): + s.logger.Info("Shutdown signal received, shutting down HTTP server gracefully") + + // Gracefully shutdown the server + if err := server.Shutdown(); err != nil { + s.logger.Error(err, "Error during server shutdown") + return err + } + + s.logger.Info("HTTP server stopped") + return nil + + case err := <-serverErr: + if err != nil { + s.logger.Error(err, "HTTP server failed") + } + return err + } +} + +// readRequest reads and parses data from the body of the given request according the type defined by isChatCompletion +func (s *VllmSimulator) readRequest(ctx *fasthttp.RequestCtx, isChatCompletion bool) (openaiserverapi.CompletionRequest, error) { + requestID := common.GenerateUUIDString() + + if isChatCompletion { + var req openaiserverapi.ChatCompletionRequest + + err := json.Unmarshal(ctx.Request.Body(), &req) + if err != nil { + s.logger.Error(err, "failed to unmarshal request body") + return nil, err + } + + for _, tool := range req.Tools { + toolJson, err := json.Marshal(tool.Function) + if err != nil { + s.logger.Error(err, "failed to marshal request tools") + return nil, err + } + err = s.toolsValidator.ValidateTool(toolJson) + if err != nil { + s.logger.Error(err, "tool validation failed") + return nil, err + } + } + req.RequestID = requestID + + return &req, nil + } + + var req openaiserverapi.TextCompletionRequest + err := json.Unmarshal(ctx.Request.Body(), &req) + + req.RequestID = requestID + + return &req, err +} + +// HandleChatCompletions http handler for /v1/chat/completions +func (s *VllmSimulator) HandleChatCompletions(ctx *fasthttp.RequestCtx) { + 
s.logger.Info("chat completion request received") + s.handleCompletions(ctx, true) +} + +// HandleTextCompletions http handler for /v1/completions +func (s *VllmSimulator) HandleTextCompletions(ctx *fasthttp.RequestCtx) { + s.logger.Info("completion request received") + s.handleCompletions(ctx, false) +} + +// readTokenizeRequest reads and parses data from the body of the given request +func (s *VllmSimulator) readTokenizeRequest(ctx *fasthttp.RequestCtx) (*vllmapi.TokenizeRequest, error) { + var tokenizeReq vllmapi.TokenizeRequest + if err := json.Unmarshal(ctx.Request.Body(), &tokenizeReq); err != nil { + s.logger.Error(err, "failed to unmarshal tokenize request body") + return nil, err + } + return &tokenizeReq, nil +} + +// HandleTokenize http handler for /tokenize +func (s *VllmSimulator) HandleTokenize(ctx *fasthttp.RequestCtx) { + s.logger.Info("tokenize request received") + req, err := s.readTokenizeRequest(ctx) + if err != nil { + s.logger.Error(err, "failed to read and parse tokenize request body") + ctx.Error("Failed to read and parse tokenize request body, "+err.Error(), fasthttp.StatusBadRequest) + return + } + + // Check that the request has only one input to tokenize + if req.Prompt != "" && req.Messages != nil { + s.sendCompletionError(ctx, openaiserverapi.NewCompletionError("both prompt and messages fields in tokenize request", + fasthttp.StatusBadRequest, nil), false) + return + } + // Model is optional, if not set, the model from the configuration will be used + model := req.Model + if model == "" { + model = s.config.Model + } + + tokens, _, err := s.tokenizer.Encode(req.GetPrompt(), model) + if err != nil { + s.logger.Error(err, "failed to tokenize") + ctx.Error("Failed to tokenize, "+err.Error(), fasthttp.StatusInternalServerError) + return + } + resp := vllmapi.TokenizeResponse{ + Count: len(tokens), + Tokens: tokens, + MaxModelLen: s.config.MaxModelLen, + } + data, err := json.Marshal(resp) + if err != nil { + ctx.Error("Response body creation failed, "+err.Error(), fasthttp.StatusInternalServerError) + return + } + ctx.Response.Header.SetContentType("application/json") + ctx.Response.Header.SetStatusCode(fasthttp.StatusOK) + ctx.Response.SetBody(data) +} + +func (s *VllmSimulator) HandleLoadLora(ctx *fasthttp.RequestCtx) { + s.logger.Info("load lora request received") + s.loadLora(ctx) +} + +func (s *VllmSimulator) HandleUnloadLora(ctx *fasthttp.RequestCtx) { + s.logger.Info("unload lora request received") + s.unloadLora(ctx) +} + +func (s *VllmSimulator) validateRequest(req openaiserverapi.CompletionRequest) (string, int) { + if !s.isValidModel(req.GetModel()) { + return fmt.Sprintf("The model `%s` does not exist.", req.GetModel()), fasthttp.StatusNotFound + } + + if req.GetMaxCompletionTokens() != nil && *req.GetMaxCompletionTokens() <= 0 { + return "Max completion tokens and max tokens should be positive", fasthttp.StatusBadRequest + } + + if req.IsDoRemoteDecode() && req.IsStream() { + return "Prefill does not support streaming", fasthttp.StatusBadRequest + } + + if req.GetIgnoreEOS() && req.GetMaxCompletionTokens() == nil { + return "Ignore_eos is true but max_completion_tokens (or max_tokens) is not set", fasthttp.StatusBadRequest + } + + // Validate context window constraints + promptTokens := req.GetNumberOfPromptTokens() + completionTokens := req.GetMaxCompletionTokens() + isValid, actualCompletionTokens, totalTokens := common.ValidateContextWindow(promptTokens, completionTokens, s.config.MaxModelLen) + if !isValid { + message := fmt.Sprintf("This model's 
maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion", + s.config.MaxModelLen, totalTokens, promptTokens, actualCompletionTokens) + return message, fasthttp.StatusBadRequest + } + return "", fasthttp.StatusOK +} + +// sendCompletionResponse sends a completion response +func (s *VllmSimulator) sendCompletionResponse(ctx *fasthttp.RequestCtx, resp openaiserverapi.CompletionResponse) { + data, err := json.Marshal(resp) + if err != nil { + ctx.Error("Response body creation failed, "+err.Error(), fasthttp.StatusInternalServerError) + return + } + ctx.Response.Header.SetContentType("application/json") + ctx.Response.Header.SetStatusCode(fasthttp.StatusOK) + // Add pod and namespace information to response headers for testing/debugging + if s.pod != "" { + ctx.Response.Header.Add(podHeader, s.pod) + } + if s.namespace != "" { + ctx.Response.Header.Add(namespaceHeader, s.namespace) + } + ctx.Response.SetBody(data) +} + +// sendCompletionError sends an error response for the current completion request +// isInjected indicates if this is an injected failure for logging purposes +func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, + compErr openaiserverapi.CompletionError, isInjected bool) { + if isInjected { + s.logger.Info("Injecting failure", "type", compErr.Type, "message", compErr.Message) + } else { + s.logger.Error(nil, compErr.Message) + } + + errorResp := openaiserverapi.ErrorResponse{ + Error: compErr, + } + + data, err := json.Marshal(errorResp) + if err != nil { + ctx.Error(err.Error(), fasthttp.StatusInternalServerError) + } else { + ctx.SetContentType("application/json") + ctx.SetStatusCode(compErr.Code) + ctx.SetBody(data) + } +} + +// HandleModels handles /v1/models request according the data stored in the simulator +func (s *VllmSimulator) HandleModels(ctx *fasthttp.RequestCtx) { + modelsResp := s.createModelsResponse() + + data, err := json.Marshal(modelsResp) + if err != nil { + s.logger.Error(err, "Failed to marshal models response") + ctx.Error("Failed to marshal models response, "+err.Error(), fasthttp.StatusInternalServerError) + return + } + + ctx.Response.Header.SetContentType("application/json") + ctx.Response.Header.SetStatusCode(fasthttp.StatusOK) + ctx.Response.SetBody(data) +} + +func (s *VllmSimulator) HandleError(_ *fasthttp.RequestCtx, err error) { + s.logger.Error(err, "VLLM server error") +} + +// HandleHealth http handler for /health +func (s *VllmSimulator) HandleHealth(ctx *fasthttp.RequestCtx) { + s.logger.V(4).Info("health request received") + ctx.Response.Header.SetContentType("application/json") + ctx.Response.Header.SetStatusCode(fasthttp.StatusOK) + ctx.Response.SetBody([]byte("{}")) +} + +// HandleReady http handler for /ready +func (s *VllmSimulator) HandleReady(ctx *fasthttp.RequestCtx) { + s.logger.V(4).Info("readiness request received") + ctx.Response.Header.SetContentType("application/json") + ctx.Response.Header.SetStatusCode(fasthttp.StatusOK) + ctx.Response.SetBody([]byte("{}")) +} diff --git a/pkg/llm-d-inference-sim/server_test.go b/pkg/llm-d-inference-sim/server_test.go new file mode 100644 index 00000000..ee1dfd7c --- /dev/null +++ b/pkg/llm-d-inference-sim/server_test.go @@ -0,0 +1,119 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llmdinferencesim + +import ( + "context" + "encoding/json" + "io" + "net/http" + "os" + "strings" + + "github.com/llm-d/llm-d-inference-sim/pkg/common" + vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Server", func() { + It("Should respond to /health", func() { + ctx := context.TODO() + client, err := startServer(ctx, common.ModeRandom) + Expect(err).NotTo(HaveOccurred()) + + resp, err := client.Get("http://localhost/health") + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + }) + + It("Should respond to /ready", func() { + ctx := context.TODO() + client, err := startServer(ctx, common.ModeRandom) + Expect(err).NotTo(HaveOccurred()) + + resp, err := client.Get("http://localhost/ready") + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + }) + + Context("tokenize", Ordered, func() { + tmpDir := "./tests-tmp/" + AfterAll(func() { + err := os.RemoveAll(tmpDir) + Expect(err).NotTo(HaveOccurred()) + }) + + It("Should return correct response to /tokenize chat", func() { + ctx := context.TODO() + args := []string{"cmd", "--model", qwenModelName, "--mode", common.ModeRandom, + "--tokenizers-cache-dir", tmpDir, "--max-model-len", "2048"} + client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil) + Expect(err).NotTo(HaveOccurred()) + + reqBody := `{ + "messages": [{"role": "user", "content": "This is a test"}], + "model": "Qwen/Qwen2-0.5B" + }` + resp, err := client.Post("http://localhost/tokenize", "application/json", strings.NewReader(reqBody)) + Expect(err).NotTo(HaveOccurred()) + defer func() { + err := resp.Body.Close() + Expect(err).NotTo(HaveOccurred()) + }() + + body, err := io.ReadAll(resp.Body) + Expect(err).NotTo(HaveOccurred()) + + var tokenizeResp vllmapi.TokenizeResponse + err = json.Unmarshal(body, &tokenizeResp) + Expect(err).NotTo(HaveOccurred()) + Expect(tokenizeResp.Count).To(Equal(4)) + Expect(tokenizeResp.Tokens).To(HaveLen(4)) + Expect(tokenizeResp.MaxModelLen).To(Equal(2048)) + }) + + It("Should return correct response to /tokenize text", func() { + ctx := context.TODO() + args := []string{"cmd", "--model", qwenModelName, "--mode", common.ModeRandom, + "--tokenizers-cache-dir", tmpDir, "--max-model-len", "2048"} + client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil) + Expect(err).NotTo(HaveOccurred()) + + reqBody := `{ + "prompt": "This is a test", + "model": "Qwen/Qwen2-0.5B" + }` + resp, err := client.Post("http://localhost/tokenize", "application/json", strings.NewReader(reqBody)) + Expect(err).NotTo(HaveOccurred()) + defer func() { + err := resp.Body.Close() + Expect(err).NotTo(HaveOccurred()) + }() + + body, err := io.ReadAll(resp.Body) + Expect(err).NotTo(HaveOccurred()) + + var tokenizeResp vllmapi.TokenizeResponse + err = json.Unmarshal(body, &tokenizeResp) + Expect(err).NotTo(HaveOccurred()) + Expect(tokenizeResp.Count).To(Equal(4)) + Expect(tokenizeResp.Tokens).To(HaveLen(4)) + Expect(tokenizeResp.MaxModelLen).To(Equal(2048)) + }) + }) +}) diff --git 
a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 32d76ee7..ab55fea2 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -19,20 +19,15 @@ package llmdinferencesim import ( "context" - "encoding/json" "fmt" - "net" "os" "strings" "sync" "time" - "github.com/buaazp/fasthttprouter" "github.com/go-logr/logr" "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/valyala/fasthttp" - "github.com/valyala/fasthttp/fasthttpadaptor" "golang.org/x/sync/errgroup" "k8s.io/klog/v2" @@ -234,243 +229,11 @@ func (s *VllmSimulator) startSim(ctx context.Context) error { return s.startServer(ctx, listener) } -func (s *VllmSimulator) newListener() (net.Listener, error) { - s.logger.Info("Server starting", "port", s.config.Port) - listener, err := net.Listen("tcp4", fmt.Sprintf(":%d", s.config.Port)) - if err != nil { - return nil, err - } - return listener, nil -} - -// startServer starts http server on port defined in command line -func (s *VllmSimulator) startServer(ctx context.Context, listener net.Listener) error { - r := fasthttprouter.New() - - // support completion APIs - r.POST("/v1/chat/completions", s.HandleChatCompletions) - r.POST("/v1/completions", s.HandleTextCompletions) - // supports /models API - r.GET("/v1/models", s.HandleModels) - // support load/unload of lora adapter - r.POST("/v1/load_lora_adapter", s.HandleLoadLora) - r.POST("/v1/unload_lora_adapter", s.HandleUnloadLora) - // supports /metrics prometheus API - r.GET("/metrics", fasthttpadaptor.NewFastHTTPHandler(promhttp.HandlerFor(s.registry, promhttp.HandlerOpts{}))) - // supports standard Kubernetes health and readiness checks - r.GET("/health", s.HandleHealth) - r.GET("/ready", s.HandleReady) - r.POST("/tokenize", s.HandleTokenize) - - server := fasthttp.Server{ - ErrorHandler: s.HandleError, - Handler: r.Handler, - Logger: s, - } - - // Start server in a goroutine - serverErr := make(chan error, 1) - go func() { - s.logger.Info("HTTP server starting") - serverErr <- server.Serve(listener) - }() - - // Wait for either context cancellation or server error - select { - case <-ctx.Done(): - s.logger.Info("Shutdown signal received, shutting down HTTP server gracefully") - - // Gracefully shutdown the server - if err := server.Shutdown(); err != nil { - s.logger.Error(err, "Error during server shutdown") - return err - } - - s.logger.Info("HTTP server stopped") - return nil - - case err := <-serverErr: - if err != nil { - s.logger.Error(err, "HTTP server failed") - } - return err - } -} - // Print prints to a log, implementation of fasthttp.Logger func (s *VllmSimulator) Printf(format string, args ...interface{}) { s.logger.Info("Server error", "msg", fmt.Sprintf(format, args...)) } -// readRequest reads and parses data from the body of the given request according the type defined by isChatCompletion -func (s *VllmSimulator) readRequest(ctx *fasthttp.RequestCtx, isChatCompletion bool) (openaiserverapi.CompletionRequest, error) { - requestID := common.GenerateUUIDString() - - if isChatCompletion { - var req openaiserverapi.ChatCompletionRequest - - err := json.Unmarshal(ctx.Request.Body(), &req) - if err != nil { - s.logger.Error(err, "failed to unmarshal request body") - return nil, err - } - - for _, tool := range req.Tools { - toolJson, err := json.Marshal(tool.Function) - if err != nil { - s.logger.Error(err, "failed to marshal request tools") - return nil, err - } - err = 
s.toolsValidator.ValidateTool(toolJson) - if err != nil { - s.logger.Error(err, "tool validation failed") - return nil, err - } - } - req.RequestID = requestID - - return &req, nil - } - - var req openaiserverapi.TextCompletionRequest - err := json.Unmarshal(ctx.Request.Body(), &req) - - req.RequestID = requestID - - return &req, err -} - -// HandleChatCompletions http handler for /v1/chat/completions -func (s *VllmSimulator) HandleChatCompletions(ctx *fasthttp.RequestCtx) { - s.logger.Info("chat completion request received") - s.handleCompletions(ctx, true) -} - -// HandleTextCompletions http handler for /v1/completions -func (s *VllmSimulator) HandleTextCompletions(ctx *fasthttp.RequestCtx) { - s.logger.Info("completion request received") - s.handleCompletions(ctx, false) -} - -// readTokenizeRequest reads and parses data from the body of the given request -func (s *VllmSimulator) readTokenizeRequest(ctx *fasthttp.RequestCtx) (*vllmapi.TokenizeRequest, error) { - var tokenizeReq vllmapi.TokenizeRequest - if err := json.Unmarshal(ctx.Request.Body(), &tokenizeReq); err != nil { - s.logger.Error(err, "failed to unmarshal tokenize request body") - return nil, err - } - return &tokenizeReq, nil -} - -// HandleTokenize http handler for /tokenize -func (s *VllmSimulator) HandleTokenize(ctx *fasthttp.RequestCtx) { - s.logger.Info("tokenize request received") - req, err := s.readTokenizeRequest(ctx) - if err != nil { - s.logger.Error(err, "failed to read and parse tokenize request body") - ctx.Error("Failed to read and parse tokenize request body, "+err.Error(), fasthttp.StatusBadRequest) - return - } - - // Check that the request has only one input to tokenize - if req.Prompt != "" && req.Messages != nil { - s.sendCompletionError(ctx, openaiserverapi.NewCompletionError("both prompt and messages fields in tokenize request", - fasthttp.StatusBadRequest, nil), false) - return - } - // Model is optional, if not set, the model from the configuration will be used - model := req.Model - if model == "" { - model = s.config.Model - } - - tokens, _, err := s.tokenizer.Encode(req.GetPrompt(), model) - if err != nil { - s.logger.Error(err, "failed to tokenize") - ctx.Error("Failed to tokenize, "+err.Error(), fasthttp.StatusInternalServerError) - return - } - resp := vllmapi.TokenizeResponse{ - Count: len(tokens), - Tokens: tokens, - MaxModelLen: s.config.MaxModelLen, - } - data, err := json.Marshal(resp) - if err != nil { - ctx.Error("Response body creation failed, "+err.Error(), fasthttp.StatusInternalServerError) - return - } - ctx.Response.Header.SetContentType("application/json") - ctx.Response.Header.SetStatusCode(fasthttp.StatusOK) - ctx.Response.SetBody(data) -} - -func (s *VllmSimulator) HandleLoadLora(ctx *fasthttp.RequestCtx) { - s.logger.Info("load lora request received") - s.loadLora(ctx) -} - -func (s *VllmSimulator) HandleUnloadLora(ctx *fasthttp.RequestCtx) { - s.logger.Info("unload lora request received") - s.unloadLora(ctx) -} - -func (s *VllmSimulator) validateRequest(req openaiserverapi.CompletionRequest) (string, int) { - if !s.isValidModel(req.GetModel()) { - return fmt.Sprintf("The model `%s` does not exist.", req.GetModel()), fasthttp.StatusNotFound - } - - if req.GetMaxCompletionTokens() != nil && *req.GetMaxCompletionTokens() <= 0 { - return "Max completion tokens and max tokens should be positive", fasthttp.StatusBadRequest - } - - if req.IsDoRemoteDecode() && req.IsStream() { - return "Prefill does not support streaming", fasthttp.StatusBadRequest - } - - if req.GetIgnoreEOS() && 
req.GetMaxCompletionTokens() == nil { - return "Ignore_eos is true but max_completion_tokens (or max_tokens) is not set", fasthttp.StatusBadRequest - } - - // Validate context window constraints - promptTokens := req.GetNumberOfPromptTokens() - completionTokens := req.GetMaxCompletionTokens() - isValid, actualCompletionTokens, totalTokens := common.ValidateContextWindow(promptTokens, completionTokens, s.config.MaxModelLen) - if !isValid { - message := fmt.Sprintf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion", - s.config.MaxModelLen, totalTokens, promptTokens, actualCompletionTokens) - return message, fasthttp.StatusBadRequest - } - return "", fasthttp.StatusOK -} - -// isValidModel checks if the given model is the base model or one of "loaded" LoRAs -func (s *VllmSimulator) isValidModel(model string) bool { - for _, name := range s.config.ServedModelNames { - if model == name { - return true - } - } - for _, lora := range s.getLoras() { - if model == lora { - return true - } - } - - return false -} - -// isLora returns true if the given model name is one of loaded LoRAs -func (s *VllmSimulator) isLora(model string) bool { - for _, lora := range s.getLoras() { - if model == lora { - return true - } - } - - return false -} - // handleCompletions general completion requests handler, support both text and chat completion APIs func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatCompletion bool) { // Check if we should inject a failure @@ -623,50 +386,6 @@ func (s *VllmSimulator) responseSentCallback(model string, isChatCompletion bool } } -// sendCompletionError sends an error response for the current completion request -// isInjected indicates if this is an injected failure for logging purposes -func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, - compErr openaiserverapi.CompletionError, isInjected bool) { - if isInjected { - s.logger.Info("Injecting failure", "type", compErr.Type, "message", compErr.Message) - } else { - s.logger.Error(nil, compErr.Message) - } - - errorResp := openaiserverapi.ErrorResponse{ - Error: compErr, - } - - data, err := json.Marshal(errorResp) - if err != nil { - ctx.Error(err.Error(), fasthttp.StatusInternalServerError) - } else { - ctx.SetContentType("application/json") - ctx.SetStatusCode(compErr.Code) - ctx.SetBody(data) - } -} - -// HandleModels handles /v1/models request according the data stored in the simulator -func (s *VllmSimulator) HandleModels(ctx *fasthttp.RequestCtx) { - modelsResp := s.createModelsResponse() - - data, err := json.Marshal(modelsResp) - if err != nil { - s.logger.Error(err, "Failed to marshal models response") - ctx.Error("Failed to marshal models response, "+err.Error(), fasthttp.StatusInternalServerError) - return - } - - ctx.Response.Header.SetContentType("application/json") - ctx.Response.Header.SetStatusCode(fasthttp.StatusOK) - ctx.Response.SetBody(data) -} - -func (s *VllmSimulator) HandleError(_ *fasthttp.RequestCtx, err error) { - s.logger.Error(err, "VLLM server error") -} - // createCompletionResponse creates the response for completion requests, supports both completion request types (text and chat) // as defined by isChatCompletion // respTokens - tokenized content to be sent in the response @@ -733,61 +452,20 @@ func (s *VllmSimulator) sendResponse(reqCtx *openaiserverapi.CompletionReqCtx, r resp := 
s.createCompletionResponse(reqCtx.IsChatCompletion, respTokens, toolCalls, &finishReason, usageData, modelName, reqCtx.CompletionReq.IsDoRemoteDecode()) - ctx := reqCtx.HTTPReqCtx - data, err := json.Marshal(resp) - if err != nil { - ctx.Error("Response body creation failed, "+err.Error(), fasthttp.StatusInternalServerError) - return - } - // calculate how long to wait before returning the response, time is based on number of tokens nCachedPromptTokens := reqCtx.CompletionReq.GetNumberOfCachedPromptTokens() - ttft := s.getTimeToFirstToken(usageData.PromptTokens, nCachedPromptTokens, reqCtx.CompletionReq.IsDoRemotePrefill()) + ttft := s.getWaitTimeToFirstToken(usageData.PromptTokens, nCachedPromptTokens, reqCtx.CompletionReq.IsDoRemotePrefill()) time.Sleep(time.Duration(ttft) * time.Millisecond) for range usageData.CompletionTokens - 1 { perTokenLatency := s.getInterTokenLatency() time.Sleep(time.Duration(perTokenLatency) * time.Millisecond) } - ctx.Response.Header.SetContentType("application/json") - ctx.Response.Header.SetStatusCode(fasthttp.StatusOK) - // Add pod and namespace information to response headers for testing/debugging - if s.pod != "" { - ctx.Response.Header.Add(podHeader, s.pod) - } - if s.namespace != "" { - ctx.Response.Header.Add(namespaceHeader, s.namespace) - } - ctx.Response.SetBody(data) + s.sendCompletionResponse(reqCtx.HTTPReqCtx, resp) s.responseSentCallback(modelName, reqCtx.IsChatCompletion, reqCtx.CompletionReq.GetRequestID()) } -// returns time to first token based on the current request's doRemotePrefill -func (s *VllmSimulator) getTimeToFirstToken(nPromptTokens int, nCachedPromptTokens int, doRemotePrefill bool) int { - if doRemotePrefill { - if s.config.KVCacheTransferLatency == 0 && s.config.KVCacheTransferLatencyStdDev == 0 { - // is disaggregated PD and ttft is calculated using number of prompt tokens - kvCacheTransT := s.config.KVCacheTransferTimePerToken * nPromptTokens - return common.RandomNorm(kvCacheTransT, s.config.KVCacheTransferTimeStdDev) - } - // is disaggregated PD and *not* using number of prompt tokens - return common.RandomNorm(s.config.KVCacheTransferLatency, s.config.KVCacheTransferLatencyStdDev) - } - if s.config.TimeToFirstToken == 0 && s.config.TimeToFirstTokenStdDev == 0 { - // is aggregated PD and ttft is calculated using number of prompt tokens that are not in kv cache - prefillTime := s.GetPrefillOverhead() + (nPromptTokens-nCachedPromptTokens)*s.GetPrefillTimePerToken() - return common.RandomNorm(prefillTime, s.config.PrefillTimeStdDev) - } - // is aggregated PD and *not* using number of prompt tokens - return common.RandomNorm(s.GetTimeToFirstToken(), s.config.TimeToFirstTokenStdDev) -} - -// returns inter token latency -func (s *VllmSimulator) getInterTokenLatency() int { - return common.RandomNorm(s.GetInterTokenLatency(), s.config.InterTokenLatencyStdDev) -} - // createModelsResponse creates and returns ModelResponse for the current state, returned array of models contains the base model + LoRA adapters if exist func (s *VllmSimulator) createModelsResponse() *vllmapi.ModelsResponse { modelsResp := vllmapi.ModelsResponse{Object: "list", Data: []vllmapi.ModelsResponseModelInfo{}} @@ -819,86 +497,3 @@ func (s *VllmSimulator) createModelsResponse() *vllmapi.ModelsResponse { return &modelsResp } - -// HandleHealth http handler for /health -func (s *VllmSimulator) HandleHealth(ctx *fasthttp.RequestCtx) { - s.logger.V(4).Info("health request received") - ctx.Response.Header.SetContentType("application/json") - 
diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go
index 8b78cc58..59f92175 100644
--- a/pkg/llm-d-inference-sim/simulator_test.go
+++ b/pkg/llm-d-inference-sim/simulator_test.go
@@ -18,7 +18,6 @@ package llmdinferencesim
 import (
 	"context"
-	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
@@ -26,11 +25,9 @@ import (
 	"net/http"
 	"os"
 	"strings"
-	"time"
 	"github.com/llm-d/llm-d-inference-sim/pkg/common"
 	kvcache "github.com/llm-d/llm-d-inference-sim/pkg/kv-cache"
-	vllmapi "github.com/llm-d/llm-d-inference-sim/pkg/vllm-api"
 	"github.com/llm-d/llm-d-kv-cache-manager/pkg/tokenization"
 	. "github.com/onsi/ginkgo/v2"
 	. "github.com/onsi/gomega"
@@ -156,17 +153,7 @@ var _ = Describe("Simulator", func() {
 			client, err := startServer(ctx, mode)
 			Expect(err).NotTo(HaveOccurred())
-			openaiclient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client))
-
-			params := openai.ChatCompletionNewParams{
-				Messages: []openai.ChatCompletionMessageParamUnion{
-					openai.UserMessage(userMessage),
-				},
-				Model: model,
-				StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)},
-			}
+			openaiclient, params := getOpenAIClentAndChatParams(client, model, userMessage, true)
 			stream := openaiclient.Chat.Completions.NewStreaming(ctx, params)
 			defer func() {
 				err := stream.Close()
@@ -219,17 +206,7 @@ var _ = Describe("Simulator", func() {
 			client, err := startServer(ctx, mode)
 			Expect(err).NotTo(HaveOccurred())
-			openaiclient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client))
-
-			params := openai.CompletionNewParams{
-				Prompt: openai.CompletionNewParamsPromptUnion{
-					OfString: openai.String(userMessage),
-				},
-				Model: openai.CompletionNewParamsModel(model),
-				StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)},
-			}
+			openaiclient, params := getOpenAIClentAndCompletionParams(client, model, userMessage, true)
 			stream := openaiclient.Completions.NewStreaming(ctx, params)
 			defer func() {
 				err := stream.Close()
@@ -277,16 +254,7 @@ var _ = Describe("Simulator", func() {
 			client, err := startServer(ctx, mode)
 			Expect(err).NotTo(HaveOccurred())
-			openaiclient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client))
-
-			params := openai.ChatCompletionNewParams{
-				Messages: []openai.ChatCompletionMessageParamUnion{
-					openai.UserMessage(userMessage),
-				},
-				Model: model,
-			}
+			openaiclient, params := getOpenAIClentAndChatParams(client, model, userMessage, false)
 			numTokens := 0
 			// if maxTokens and maxCompletionTokens are passed
 			// maxCompletionTokens is used
@@ -363,15 +331,7 @@ var _ = Describe("Simulator", func() {
 			client, err := startServer(ctx, mode)
 			Expect(err).NotTo(HaveOccurred())
-			openaiclient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client))
-			params := openai.CompletionNewParams{
-				Prompt: openai.CompletionNewParamsPromptUnion{
-					OfString: openai.String(userMessage),
-				},
-				Model: openai.CompletionNewParamsModel(model),
-			}
+			openaiclient, params := getOpenAIClentAndCompletionParams(client, model, userMessage, false)
 			numTokens := 0
 			if maxTokens != 0 {
 				params.MaxTokens = param.NewOpt(int64(maxTokens))
@@ -427,48 +387,9 @@ var _ = Describe("Simulator", func() {
 		Entry(nil, common.ModeEcho, -1),
 	)
-	It("Should respond to /health", func() {
-		ctx := context.TODO()
-		client, err := startServer(ctx, common.ModeRandom)
-		Expect(err).NotTo(HaveOccurred())
-
-		resp, err := client.Get("http://localhost/health")
-		Expect(err).NotTo(HaveOccurred())
-		Expect(resp.StatusCode).To(Equal(http.StatusOK))
-	})
-
-	It("Should respond to /ready", func() {
-		ctx := context.TODO()
-		client, err := startServer(ctx, common.ModeRandom)
-		Expect(err).NotTo(HaveOccurred())
-
-		resp, err := client.Get("http://localhost/ready")
-		Expect(err).NotTo(HaveOccurred())
-		Expect(resp.StatusCode).To(Equal(http.StatusOK))
-	})
-
 	Context("namespace and pod headers", func() {
 		It("Should not include namespace and pod headers in chat completion response when env is not set", func() {
-			ctx := context.TODO()
-
-			client, err := startServer(ctx, common.ModeRandom)
-			Expect(err).NotTo(HaveOccurred())
-
-			openaiclient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client))
-
-			params := openai.ChatCompletionNewParams{
-				Messages: []openai.ChatCompletionMessageParamUnion{
-					openai.UserMessage(userMessage),
-				},
-				Model: model,
-			}
-
-			var httpResp *http.Response
-			resp, err := openaiclient.Chat.Completions.New(ctx, params, option.WithResponseInto(&httpResp))
-			Expect(err).NotTo(HaveOccurred())
-			Expect(resp).NotTo(BeNil())
+			httpResp := sendSimpleChatRequest(nil, false)
 			// Check for namespace and pod headers
 			namespaceHeader := httpResp.Header.Get(namespaceHeader)
@@ -479,32 +400,13 @@ var _ = Describe("Simulator", func() {
 		})
 		It("Should include namespace and pod headers in chat completion response", func() {
-			ctx := context.TODO()
-
 			testNamespace := "test-namespace"
 			testPod := "test-pod"
 			envs := map[string]string{
 				podNameEnv: testPod,
 				podNsEnv: testNamespace,
 			}
-			client, err := startServerWithArgs(ctx, common.ModeRandom, nil, envs)
-			Expect(err).NotTo(HaveOccurred())
-
-			openaiclient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client))
-
-			params := openai.ChatCompletionNewParams{
-				Messages: []openai.ChatCompletionMessageParamUnion{
-					openai.UserMessage(userMessage),
-				},
-				Model: model,
-			}
-
-			var httpResp *http.Response
-			resp, err := openaiclient.Chat.Completions.New(ctx, params, option.WithResponseInto(&httpResp))
-			Expect(err).NotTo(HaveOccurred())
-			Expect(resp).NotTo(BeNil())
+			httpResp := sendSimpleChatRequest(envs, false)
 			// Check for namespace and pod headers
 			namespaceHeader := httpResp.Header.Get(namespaceHeader)
@@ -515,33 +417,13 @@ var _ = Describe("Simulator", func() {
 		})
 		It("Should include namespace and pod headers in chat completion streaming response", func() {
-			ctx := context.TODO()
-
 			testNamespace := "stream-test-namespace"
 			testPod := "stream-test-pod"
 			envs := map[string]string{
 				podNameEnv: testPod,
 				podNsEnv: testNamespace,
 			}
-			client, err := startServerWithArgs(ctx, common.ModeRandom, nil, envs)
-			Expect(err).NotTo(HaveOccurred())
-
-			openaiclient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client))
-
-			params := openai.ChatCompletionNewParams{
-				Messages: []openai.ChatCompletionMessageParamUnion{
-					openai.UserMessage(userMessage),
-				},
-				Model: model,
-				StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)},
-			}
-
-			var httpResp *http.Response
-			resp, err := openaiclient.Chat.Completions.New(ctx, params, option.WithResponseInto(&httpResp))
-			Expect(err).NotTo(HaveOccurred())
-			Expect(resp).NotTo(BeNil())
+			httpResp := sendSimpleChatRequest(envs, true)
 			// Check for namespace and pod headers
 			namespaceHeader := httpResp.Header.Get(namespaceHeader)
@@ -552,27 +434,7 @@ var _ = Describe("Simulator", func() {
 		})
 		It("Should not include namespace and pod headers in chat completion streaming response when env is not set", func() {
-			ctx := context.TODO()
-
-			client, err := startServer(ctx, common.ModeRandom)
-			Expect(err).NotTo(HaveOccurred())
-
-			openaiclient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client))
-
-			params := openai.ChatCompletionNewParams{
-				Messages: []openai.ChatCompletionMessageParamUnion{
-					openai.UserMessage(userMessage),
-				},
-				Model: model,
-				StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)},
-			}
-
-			var httpResp *http.Response
-			resp, err := openaiclient.Chat.Completions.New(ctx, params,
-				option.WithResponseInto(&httpResp))
-			Expect(err).NotTo(HaveOccurred())
-			Expect(resp).NotTo(BeNil())
+			httpResp := sendSimpleChatRequest(nil, true)
 			// Check for namespace and pod headers
 			namespaceHeader := httpResp.Header.Get(namespaceHeader)
@@ -594,16 +456,7 @@ var _ = Describe("Simulator", func() {
 			client, err := startServerWithArgs(ctx, common.ModeRandom, nil, envs)
 			Expect(err).NotTo(HaveOccurred())
-			openaiclient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client))
-
-			params := openai.CompletionNewParams{
-				Prompt: openai.CompletionNewParamsPromptUnion{
-					OfString: openai.String(userMessage),
-				},
-				Model: openai.CompletionNewParamsModel(model),
-			}
+			openaiclient, params := getOpenAIClentAndCompletionParams(client, model, userMessage, false)
 			var httpResp *http.Response
 			resp, err := openaiclient.Completions.New(ctx, params, option.WithResponseInto(&httpResp))
 			Expect(err).NotTo(HaveOccurred())
@@ -629,17 +482,7 @@ var _ = Describe("Simulator", func() {
 			client, err := startServerWithArgs(ctx, common.ModeRandom, nil, envs)
 			Expect(err).NotTo(HaveOccurred())
-			openaiclient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client))
-
-			params := openai.CompletionNewParams{
-				Prompt: openai.CompletionNewParamsPromptUnion{
-					OfString: openai.String(userMessage),
-				},
-				Model: openai.CompletionNewParamsModel(model),
-				StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)},
-			}
+			openaiclient, params := getOpenAIClentAndCompletionParams(client, model, userMessage, true)
 			var httpResp *http.Response
 			resp, err := openaiclient.Completions.New(ctx, params, option.WithResponseInto(&httpResp))
 			Expect(err).NotTo(HaveOccurred())
@@ -686,19 +529,10 @@ var _ = Describe("Simulator", func() {
 			Expect(string(body)).To(ContainSubstring("BadRequestError"))
 			// Also test with OpenAI client to ensure it gets an error
-			openaiclient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client),
-			)
-
-			_, err = openaiclient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
-				Messages: []openai.ChatCompletionMessageParamUnion{
-					openai.UserMessage("This is a test message"),
-				},
-				Model: model,
-				MaxTokens: openai.Int(8),
-			})
+			openaiclient, params := getOpenAIClentAndChatParams(client, model, "This is a test message", false)
+			params.MaxTokens = openai.Int(8)
+			_, err = openaiclient.Chat.Completions.New(ctx, params)
 			Expect(err).To(HaveOccurred())
 			var apiErr *openai.Error
 			Expect(errors.As(err, &apiErr)).To(BeTrue())
@@ -712,19 +546,11 @@ var _ = Describe("Simulator", func() {
 			client, err := startServerWithArgs(ctx, common.ModeEcho, args, nil)
 			Expect(err).NotTo(HaveOccurred())
-			openaiclient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client),
-			)
+			openaiclient, params := getOpenAIClentAndChatParams(client, model, "Hello", false)
+			params.MaxTokens = openai.Int(5)
 			// Send a request within the context window
-			resp, err := openaiclient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
-				Messages: []openai.ChatCompletionMessageParamUnion{
-					openai.UserMessage("Hello"),
-				},
-				Model: model,
-				MaxTokens: openai.Int(5),
-			})
+			resp, err := openaiclient.Chat.Completions.New(ctx, params)
 			Expect(err).NotTo(HaveOccurred())
 			Expect(resp.Choices).To(HaveLen(1))
@@ -760,389 +586,59 @@ var _ = Describe("Simulator", func() {
 			Expect(string(body)).To(ContainSubstring("BadRequestError"))
 		})
 	})
+})
-
-	Describe("Check random latencies", Ordered, func() {
-		var simulator *VllmSimulator
-
-		BeforeAll(func() {
-			var err error
-			simulator, err = New(klog.Background())
-			Expect(err).NotTo(HaveOccurred())
-
-			simulator.config = &common.Configuration{
-				TimeToFirstToken: 2048,
-				TimeToFirstTokenStdDev: 2048,
-				KVCacheTransferLatency: 2048,
-				KVCacheTransferLatencyStdDev: 2048,
-			}
-
-			common.InitRandom(time.Now().UnixNano())
-		})
-
-		DescribeTable("should calculate inter token latency correctly",
-			func(interTokenLatency int, stddev int) {
-				simulator.config.InterTokenLatency = interTokenLatency
-				simulator.config.InterTokenLatencyStdDev = stddev
-				interToken := simulator.getInterTokenLatency()
-				Expect(interToken).To(BeNumerically(">=", int(float32(interTokenLatency)*0.3)))
-				Expect(interToken).To(BeNumerically("<=", int(float32(interTokenLatency)*1.7)))
-			},
-			func(interTokenLatency int, stddev int) string {
-				return fmt.Sprintf("interTokenLatency: %d stddev: %d", interTokenLatency, stddev)
-			},
-			Entry(nil, 1000, 300),
-			Entry(nil, 1000, 800), // invalid std dev, used for testing purposes
-			Entry(nil, 1000, 900), // invalid std dev, used for testing purposes
-			Entry(nil, 1000, 0),
-		)
-
-		DescribeTable("should calculate total inter token latency correctly",
-			func(interTokenLatency int, stddev int, numberOfTokens int) {
-				simulator.config.InterTokenLatency = interTokenLatency
-				simulator.config.InterTokenLatencyStdDev = stddev
-				simulator.config.MaxNumSeqs = 1
-				simulator.config.TimeFactorUnderLoad = 1.0
-
-				latency := 0
-				for range numberOfTokens - 1 {
-					latency += simulator.getInterTokenLatency()
-				}
-
-				Expect(latency).To(BeNumerically(">=", int(float32(interTokenLatency)*0.3*float32(numberOfTokens))))
-				Expect(latency).To(BeNumerically("<=", int(float32(interTokenLatency)*1.7*float32(numberOfTokens))))
-			},
-			func(interTokenLatency int, stddev int, numberOfTokens int) string {
-				return fmt.Sprintf("interTokenLatency: %d stddev: %d, numberOfTokens: %d", interTokenLatency,
-					stddev, numberOfTokens)
-			},
-			Entry(nil, 1000, 30, 100),
-			Entry(nil, 1000, 800, 20), // invalid std dev, used for testing purposes
-			Entry(nil, 1000, 900, 5), // invalid std dev, used for testing purposes
-			Entry(nil, 1000, 0, 50),
-		)
-
-		DescribeTable("should calculate time to first token correctly",
-			func(timeToFirstToken int, timeToFirstTokenStdDev int,
-				kvCacheLatency int, kvCacheLatencyStdDev int, doREmotePrefill bool) {
-				simulator.config.TimeToFirstToken = timeToFirstToken
-				simulator.config.TimeToFirstTokenStdDev = timeToFirstTokenStdDev
-				simulator.config.KVCacheTransferLatency = kvCacheLatency
-				simulator.config.KVCacheTransferLatencyStdDev = kvCacheLatencyStdDev
-				timeToFirst := simulator.getTimeToFirstToken(1, 0, doREmotePrefill)
-				if doREmotePrefill {
-					Expect(timeToFirst).To(BeNumerically(">=", int(float32(kvCacheLatency)*0.3)))
-					Expect(timeToFirst).To(BeNumerically("<=", int(float32(kvCacheLatency)*1.7)))
-				} else {
-					Expect(timeToFirst).To(BeNumerically(">=", int(float32(timeToFirstToken)*0.3)))
-					Expect(timeToFirst).To(BeNumerically("<=", int(float32(timeToFirstToken)*1.7)))
-				}
-			},
-			func(timeToFirstToken int, timeToFirstTokenStdDev int,
-				kvCacheLatency int, kvCacheLatencyStdDev int, doREmotePrefill bool) string {
-				return fmt.Sprintf("timeToFirstToken: %d stddev: %d kvCacheLatency: %d stddev: %d doREmotePrefill: %t",
-					timeToFirstToken, timeToFirstTokenStdDev, kvCacheLatency, kvCacheLatencyStdDev, doREmotePrefill)
-			},
-			Entry(nil, 10000, 300, 1000, 200, true),
-			Entry(nil, 10000, 300, 1000, 200, false),
-			Entry(nil, 10000, 9000, 1000, 800, true), // invalid std dev, used for testing purposes
-			Entry(nil, 10000, 8000, 1000, 900, false), // invalid std dev, used for testing purposes
-			Entry(nil, 10000, 0, 1000, 0, true),
-			Entry(nil, 10000, 0, 1000, 0, false),
-		)
-
-		It("when time-to-first-token is not 0, ignore prefill-overhead and prefill-time-per-token", func() {
-			timeToFirstToken := 1000
-			simulator.config.TimeToFirstToken = timeToFirstToken
-			simulator.config.TimeToFirstTokenStdDev = 0
-
-			simulator.config.PrefillOverhead = 100
-			simulator.config.PrefillTimePerToken = 200
-			simulator.config.PrefillTimeStdDev = 80
-
-			ttft := simulator.getTimeToFirstToken(128, 0, false)
-
-			Expect(ttft).To(BeNumerically("==", timeToFirstToken))
-		})
-
-		It("when time-to-first-token is 0, and prefill-overhead is not 0, use prefill-overhead and prefill-time-per-token", func() {
-			simulator.config.TimeToFirstToken = 0
-			simulator.config.TimeToFirstTokenStdDev = 0
-
-			simulator.config.PrefillOverhead = 100
-			simulator.config.PrefillTimePerToken = 200
-			simulator.config.PrefillTimeStdDev = 80
-
-			ttft := simulator.getTimeToFirstToken(128, 0, false)
-			Expect(ttft).NotTo(BeNumerically("==", 0))
-		})
-
-		DescribeTable("time to first token is calculated against number of prompt tokens with std",
-			func(prefillOverhead int, prefillTimePerToken int, stdDev int, nTokens int, nCachedTokens int) {
-				simulator.config.TimeToFirstToken = 0
-				simulator.config.PrefillOverhead = prefillOverhead
-				simulator.config.PrefillTimePerToken = prefillTimePerToken
-				simulator.config.PrefillTimeStdDev = stdDev
-
-				ttft := simulator.getTimeToFirstToken(nTokens, nCachedTokens, false)
-
-				expectedTTFT := prefillOverhead + prefillTimePerToken*(nTokens-nCachedTokens)
-				Expect(ttft).To(BeNumerically(">=", int(float64(expectedTTFT)*0.3)))
-				Expect(ttft).To(BeNumerically("<=", int(float64(expectedTTFT)*1.7)))
-			},
-			func(prefillOverhead int, prefillTimePerToken, stdDev int, nTokens int, nCachedTokens int) string {
-				return fmt.Sprintf("prefillOverhead: %d, prefillTimePerToken: %d, stdDev: %d, nTokens: %d nCachedTokens: %d",
-					prefillOverhead, prefillTimePerToken, stdDev, nTokens, nCachedTokens)
-			},
-			Entry("single token", 100, 50, 10, 1, 0),
-			Entry("single token big std", 100, 50, 70, 1, 0),
-			Entry("stddev is 0", 100, 50, 0, 1, 0),
-			Entry("medium overhead, 512 tokens", 200, 1000, 150, 512, 0),
-			Entry("large overhead, 1024 tokens", 2000, 3000, 800, 1024, 0),
-			Entry("very long prompt", 150, 200, 70, 20000, 0),
-			Entry("medium overhead, 512 tokens, 256 cached", 200, 1000, 150, 512, 256),
-			Entry("large overhead, 1024 tokens, 1008 cached", 2000, 3000, 800, 1024, 1008),
-			Entry("very long prompt, 1024 cached", 150, 200, 70, 20000, 1024),
-		)
-
-		DescribeTable("time to first token is calculated against number of prompt tokens",
-			func(prefillOverhead int, prefillTimePerToken int, nTokens int, nCachedTokens int) {
-				simulator.config.TimeToFirstToken = 0
-				simulator.config.PrefillOverhead = prefillOverhead
-				simulator.config.PrefillTimePerToken = prefillTimePerToken
-				simulator.config.PrefillTimeStdDev = 0
-
-				ttft := simulator.getTimeToFirstToken(nTokens, nCachedTokens, false)
-				expectedTTFT := prefillOverhead + prefillTimePerToken*(nTokens-nCachedTokens)
-				Expect(ttft).To(Equal(expectedTTFT))
-			},
-			func(prefillOverhead int, prefillTimePerToken, nTokens int, nCachedTokens int) string {
-				return fmt.Sprintf("prefillOverhead: %d, prefillTimePerToken: %d, nTokens: %d nCachedTokens: %d",
-					prefillOverhead, prefillTimePerToken, nTokens, nCachedTokens)
-			},
-			Entry("single token", 100, 50, 1, 0),
-			Entry("medium overhead, 512 tokens", 200, 1000, 512, 0),
-			Entry("large overhead, 1024 tokens", 2000, 3000, 1024, 0),
-			Entry("very long prompt", 150, 200, 20000, 0),
-			Entry("medium overhead, 512 tokens, 256 cached", 200, 1000, 512, 256),
-			Entry("large overhead, 1024 tokens, 128 cached", 2000, 3000, 1024, 128),
-			Entry("very long prompt, 1024 cached", 150, 200, 20000, 1024),
-		)
-
-		It("when kv-cache-transfer-latency is not 0, ignore kv-cache-transfer-time-per-token", func() {
-			simulator.config.KVCacheTransferLatency = 200
-			simulator.config.KVCacheTransferLatencyStdDev = 0
-
-			simulator.config.KVCacheTransferTimePerToken = 100
-			simulator.config.KVCacheTransferTimeStdDev = 0
-
-			ttft := simulator.getTimeToFirstToken(128, 0, true)
-			Expect(ttft).To(BeNumerically("==", 200))
-		})
-
-		It("when kv-cache-transfer-latency is 0, and kv-cache-transfer-time-per-token is not 0, use kv-cache-transfer-time-per-token", func() {
-			simulator.config.KVCacheTransferLatency = 0
-			simulator.config.KVCacheTransferLatencyStdDev = 0
-
-			simulator.config.KVCacheTransferTimePerToken = 100
-			simulator.config.KVCacheTransferTimeStdDev = 0
-
-			ttft := simulator.getTimeToFirstToken(128, 0, true)
-			Expect(ttft).To(BeNumerically("==", 12800))
-		})
-
-		DescribeTable("kv cache transfer time against number of prompt tokens",
-			func(kvCacheTransTPT int, stddev int, nTokens int) {
-				simulator.config.TimeToFirstToken = 0
-				simulator.config.PrefillOverhead = 1
-				simulator.config.KVCacheTransferTimePerToken = kvCacheTransTPT
-				simulator.config.KVCacheTransferTimeStdDev = stddev
-
-				ttft := simulator.getTimeToFirstToken(nTokens, 0, true)
-
-				expectedTTFT := kvCacheTransTPT * nTokens
-				Expect(ttft).To(BeNumerically(">=", int(float64(expectedTTFT)*0.3)))
-				Expect(ttft).To(BeNumerically("<=", int(float64(expectedTTFT)*1.7)))
-
-			},
-			func(kvCacheTransferTimePerToken int, stddev int, nTokens int) string {
-				return fmt.Sprintf("kvCacheTransferTimePerToken: %d stddev: %d nTokens: %d",
-					kvCacheTransferTimePerToken, stddev, nTokens)
-			},
-			Entry("single token", 100, 70, 1),
-			Entry("stddev is 0", 100, 0, 1),
-			Entry("medium overhead, 512 tokens", 200, 150, 512),
-			Entry("large overhead, 1024 tokens", 2000, 1800, 1024),
-			Entry("very long prompt", 150, 100, 20000),
-		)
-
-		It("when time-factor-under-load is 1, the time to first token should be equal to time-to-first-token", func() {
-			simulator.config.TimeToFirstToken = 42
-			simulator.config.TimeToFirstTokenStdDev = 0
-			simulator.config.TimeFactorUnderLoad = 1.0
-
-			simulator.runReqChan <- 100
-
-			ttft := simulator.getTimeToFirstToken(128, 0, false)
-			Expect(ttft).To(Equal(42))
-		})
-
-		It("when time-factor-under-load is > 1, but max-num-seqs is 1, the factor will not take effect", func() {
-			simulator.config.TimeToFirstToken = 42
-			simulator.config.TimeToFirstTokenStdDev = 0
-			simulator.config.TimeFactorUnderLoad = 100.0
-			simulator.config.MaxNumSeqs = 1
-
-			for len(simulator.runReqChan) > 0 {
-				<-simulator.runReqChan
-			}
-
-			simulator.runReqChan <- 1
-
-			ttft := simulator.getTimeToFirstToken(128, 0, false)
-			Expect(ttft).To(Equal(42))
-		})
-
-		DescribeTable("when time-factor-under-load is > 1, and the sim is fully loaded, the time to first token should be time-factor-under-load * time-to-first-token",
-			func(timeFactorUnderLoad float64, maxNumOfReq int) {
-				simulator.config.TimeToFirstToken = 42
-				simulator.config.TimeToFirstTokenStdDev = 0
-				simulator.config.TimeFactorUnderLoad = timeFactorUnderLoad
-				simulator.config.MaxNumSeqs = maxNumOfReq
-				simulator.nRunningReqs = int64(maxNumOfReq)
-
-				ttft := simulator.getTimeToFirstToken(128, 0, false)
-				Expect(ttft).To(Equal(int(float64(42) * timeFactorUnderLoad)))
-
-			},
-			func(timeFactorUnderLoad float64, maxNumOfReq int64) string {
-				return fmt.Sprintf("timeFactorUnderLoad: %f maxNumOfReq: %d",
-					timeFactorUnderLoad, maxNumOfReq)
-			},
-
-			Entry("factor: 1.5", 1.5, 70),
-			Entry("factor: 2.0", 2.0, 2),
-			Entry("factor: 100.0", 100.0, 150),
-			Entry("factor: 20000.0", 20000.0, 310),
-		)
-
-		DescribeTable("when time-factor-under-load is > 1, and the sim is partially loaded, the time to first token should be linear interpolation between time-to-first-token and time-factor-under-load * time-to-first-token",
-			func(timeFactorUnderLoad float64, maxNumOfReq int, nCurrNumOfReq int) {
-				simulator.config.TimeToFirstToken = 42
-				simulator.config.TimeToFirstTokenStdDev = 0
-				simulator.config.TimeFactorUnderLoad = timeFactorUnderLoad
-				simulator.config.MaxNumSeqs = maxNumOfReq
-				simulator.nRunningReqs = int64(nCurrNumOfReq)
-
-				ttft := simulator.getTimeToFirstToken(128, 0, false)
-				max := timeFactorUnderLoad * float64(42)
-				Expect(ttft).To(BeNumerically(">=", 42))
-				Expect(ttft).To(BeNumerically("<=", max))
-
-			},
-			func(timeFactorUnderLoad float64, maxNumOfReq int, nCurrNumOfReq int) string {
-				return fmt.Sprintf("timeFactorUnderLoad: %f maxNumOfReq: %d nCurrNumOfReq: %d",
-					timeFactorUnderLoad, maxNumOfReq, nCurrNumOfReq)
-			},
-
-			Entry("factor: 1.5", 1.5, 70, 35),
-			Entry("factor: 2.0", 2.0, 2, 1),
-			Entry("factor: 100.0", 100.0, 150, 75),
-			Entry("factor: 20000.0", 20000.0, 310, 155),
-		)
-
-		It("when TimeFactorUnderLoad is 1.0, calcLoadFactor should give 1", func() {
-			simulator.config.TimeFactorUnderLoad = 1.0
-			simulator.config.MaxNumSeqs = 11
-			simulator.nRunningReqs = 3
-
-			factor := simulator.getCurrFactor()
-			Expect(factor).To(BeNumerically("==", 1.0))
-		})
-
-		It("when TimeFactorUnderLoad is > 1.0, and sim is fully loaded, calcLoadFactor should give TimeFactorUnderLoad", func() {
-			simulator.config.TimeFactorUnderLoad = 2.0
-			simulator.config.MaxNumSeqs = 11
-			simulator.nRunningReqs = 11
-
-			factor := simulator.getCurrFactor()
-			Expect(factor).To(BeNumerically("==", simulator.config.TimeFactorUnderLoad))
-
-		})
-
-		It("when TimeFactorUnderLoad is > 1.0, and sim is partially loaded, calcLoadFactor should give a value between 1 and TimeFactorUnderLoad", func() {
-			simulator.config.TimeFactorUnderLoad = 2.0
-			simulator.config.MaxNumSeqs = 11
-			simulator.nRunningReqs = 6
-
-			factor := simulator.getCurrFactor()
-			Expect(factor).To(BeNumerically(">", 1.0))
-			Expect(factor).To(BeNumerically("<", simulator.config.TimeFactorUnderLoad))
-		})
-	})
-
-	Context("tokenize", Ordered, func() {
-		tmpDir := "./tests-tmp/"
-		AfterAll(func() {
-			err := os.RemoveAll(tmpDir)
-			Expect(err).NotTo(HaveOccurred())
-		})
+func sendSimpleChatRequest(envs map[string]string, streaming bool) *http.Response {
+	ctx := context.TODO()
-		It("Should return correct response to /tokenize chat", func() {
-			ctx := context.TODO()
-			args := []string{"cmd", "--model", qwenModelName, "--mode", common.ModeRandom,
-				"--tokenizers-cache-dir", tmpDir, "--max-model-len", "2048"}
-			client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
-			Expect(err).NotTo(HaveOccurred())
+	client, err := startServerWithArgs(ctx, common.ModeRandom, nil, envs)
+	Expect(err).NotTo(HaveOccurred())
-			reqBody := `{
-				"messages": [{"role": "user", "content": "This is a test"}],
-				"model": "Qwen/Qwen2-0.5B"
-			}`
-			resp, err := client.Post("http://localhost/tokenize", "application/json", strings.NewReader(reqBody))
-			Expect(err).NotTo(HaveOccurred())
-			defer func() {
-				err := resp.Body.Close()
-				Expect(err).NotTo(HaveOccurred())
-			}()
+	openaiclient, params := getOpenAIClentAndChatParams(client, model, userMessage, streaming)
+	var httpResp *http.Response
+	resp, err := openaiclient.Chat.Completions.New(ctx, params,
+		option.WithResponseInto(&httpResp))
+	Expect(err).NotTo(HaveOccurred())
+	Expect(resp).NotTo(BeNil())
-			body, err := io.ReadAll(resp.Body)
-			Expect(err).NotTo(HaveOccurred())
+	Expect(resp.Choices).ShouldNot(BeEmpty())
+	Expect(string(resp.Object)).To(Equal(chatCompletionObject))
-			var tokenizeResp vllmapi.TokenizeResponse
-			err = json.Unmarshal(body, &tokenizeResp)
-			Expect(err).NotTo(HaveOccurred())
-			Expect(tokenizeResp.Count).To(Equal(4))
-			Expect(tokenizeResp.Tokens).To(HaveLen(4))
-			Expect(tokenizeResp.MaxModelLen).To(Equal(2048))
-		})
+	return httpResp
+}
-		It("Should return correct response to /tokenize text", func() {
-			ctx := context.TODO()
-			args := []string{"cmd", "--model", qwenModelName, "--mode", common.ModeRandom,
-				"--tokenizers-cache-dir", tmpDir, "--max-model-len", "2048"}
-			client, err := startServerWithArgs(ctx, common.ModeRandom, args, nil)
-			Expect(err).NotTo(HaveOccurred())
+func getOpenAIClentAndChatParams(client option.HTTPClient, model string, message string,
+	streaming bool) (openai.Client, openai.ChatCompletionNewParams) {
+	openaiclient := openai.NewClient(
+		option.WithBaseURL(baseURL),
+		option.WithHTTPClient(client))
-			reqBody := `{
-				"prompt": "This is a test",
-				"model": "Qwen/Qwen2-0.5B"
-			}`
-			resp, err := client.Post("http://localhost/tokenize", "application/json", strings.NewReader(reqBody))
-			Expect(err).NotTo(HaveOccurred())
-			defer func() {
-				err := resp.Body.Close()
-				Expect(err).NotTo(HaveOccurred())
-			}()
+	params := openai.ChatCompletionNewParams{
+		Messages: []openai.ChatCompletionMessageParamUnion{
+			openai.UserMessage(message),
+		},
+		Model: model,
+	}
+	if streaming {
+		params.StreamOptions = openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)}
+	}
+	return openaiclient, params
+}
-			body, err := io.ReadAll(resp.Body)
-			Expect(err).NotTo(HaveOccurred())
+// nolint
+func getOpenAIClentAndCompletionParams(client option.HTTPClient, model string, message string,
+	streaming bool) (openai.Client, openai.CompletionNewParams) {
+	openaiclient := openai.NewClient(
+		option.WithBaseURL(baseURL),
+		option.WithHTTPClient(client))
-			var tokenizeResp vllmapi.TokenizeResponse
-			err = json.Unmarshal(body, &tokenizeResp)
-			Expect(err).NotTo(HaveOccurred())
-			Expect(tokenizeResp.Count).To(Equal(4))
-			Expect(tokenizeResp.Tokens).To(HaveLen(4))
-			Expect(tokenizeResp.MaxModelLen).To(Equal(2048))
-		})
-	})
-})
+	params := openai.CompletionNewParams{
+		Prompt: openai.CompletionNewParamsPromptUnion{
+			OfString: openai.String(message),
+		},
+		Model: openai.CompletionNewParamsModel(model),
+	}
+	if streaming {
+		params.StreamOptions = openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)}
+	}
+	return openaiclient, params
+}
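Note for reviewers: the shared helpers above return both the client and a baseline params struct, and the refactored tests layer their per-test overrides on top before issuing the call. A typical non-streaming use, taken from the hunks above (client, model, userMessage, and ctx come from the surrounding test file; the MaxTokens value is just the one used in the context-window test):

	// Build the client and default chat params, then apply per-test tweaks.
	openaiclient, params := getOpenAIClentAndChatParams(client, model, userMessage, false)
	params.MaxTokens = openai.Int(8) // per-test override layered on the shared defaults
	resp, err := openaiclient.Chat.Completions.New(ctx, params)

Passing true instead of false adds StreamOptions with IncludeUsage, which is what the streaming specs rely on.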
diff --git a/pkg/llm-d-inference-sim/streaming.go b/pkg/llm-d-inference-sim/streaming.go
index ea9b6676..2508298d 100644
--- a/pkg/llm-d-inference-sim/streaming.go
+++ b/pkg/llm-d-inference-sim/streaming.go
@@ -100,7 +100,7 @@ func (s *VllmSimulator) sendStreamingResponse(context *streamingContext, respons
 func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writer, genTokens []string, tc *openaiserverapi.ToolCall, finishReason string) {
 	// time to first token delay
-	ttft := s.getTimeToFirstToken(context.nPromptTokens, context.nCachedPromptTokens, context.doRemotePrefill)
+	ttft := s.getWaitTimeToFirstToken(context.nPromptTokens, context.nCachedPromptTokens, context.doRemotePrefill)
 	time.Sleep(time.Duration(ttft) * time.Millisecond)
 	for i, token := range genTokens {
diff --git a/pkg/llm-d-inference-sim/tools_test.go b/pkg/llm-d-inference-sim/tools_test.go
index c996db59..ae22a7f6 100644
--- a/pkg/llm-d-inference-sim/tools_test.go
+++ b/pkg/llm-d-inference-sim/tools_test.go
@@ -345,19 +345,10 @@ var _ = Describe("Simulator for request with tools", func() {
 			client, err := startServer(ctx, mode)
 			Expect(err).NotTo(HaveOccurred())
-			openaiclient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client))
+			openaiclient, params := getOpenAIClentAndChatParams(client, model, userMessage, true)
+			params.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")}
+			params.Tools = tools
-			params := openai.ChatCompletionNewParams{
-				Messages: []openai.ChatCompletionMessageParamUnion{
-					openai.UserMessage(userMessage),
-				},
-				Model: model,
-				StreamOptions: openai.ChatCompletionStreamOptionsParam{IncludeUsage: param.NewOpt(true)},
-				ToolChoice: openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")},
-				Tools: tools,
-			}
 			stream := openaiclient.Chat.Completions.NewStreaming(ctx, params)
 			defer func() {
 				err := stream.Close()
@@ -436,16 +427,9 @@ var _ = Describe("Simulator for request with tools", func() {
 			client, err := startServer(ctx, mode)
 			Expect(err).NotTo(HaveOccurred())
-			openaiclient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client))
-
-			params := openai.ChatCompletionNewParams{
-				Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage(userMessage)},
-				Model: model,
-				ToolChoice: openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")},
-				Tools: tools,
-			}
+			openaiclient, params := getOpenAIClentAndChatParams(client, model, userMessage, false)
+			params.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")}
+			params.Tools = tools
 			resp, err := openaiclient.Chat.Completions.New(ctx, params)
 			Expect(err).NotTo(HaveOccurred())
@@ -528,16 +512,9 @@ var _ = Describe("Simulator for request with tools", func() {
 			client, err := startServerWithArgs(ctx, common.ModeEcho, serverArgs, nil)
 			Expect(err).NotTo(HaveOccurred())
-			openaiclient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client))
-
-			params := openai.ChatCompletionNewParams{
-				Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage(userMessage)},
-				Model: model,
-				ToolChoice: openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")},
-				Tools: toolWithArray,
-			}
+			openaiclient, params := getOpenAIClentAndChatParams(client, model, userMessage, false)
+			params.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")}
+			params.Tools = toolWithArray
 			resp, err := openaiclient.Chat.Completions.New(ctx, params)
 			Expect(err).NotTo(HaveOccurred())
@@ -584,16 +561,9 @@ var _ = Describe("Simulator for request with tools", func() {
 			client, err := startServer(ctx, mode)
 			Expect(err).NotTo(HaveOccurred())
-			openaiclient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client))
-
-			params := openai.ChatCompletionNewParams{
-				Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage(userMessage)},
-				Model: model,
-				ToolChoice: openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")},
-				Tools: toolWith3DArray,
-			}
+			openaiclient, params := getOpenAIClentAndChatParams(client, model, userMessage, false)
+			params.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")}
+			params.Tools = toolWith3DArray
 			resp, err := openaiclient.Chat.Completions.New(ctx, params)
 			Expect(err).NotTo(HaveOccurred())
@@ -644,16 +614,9 @@ var _ = Describe("Simulator for request with tools", func() {
 			client, err := startServer(ctx, mode)
 			Expect(err).NotTo(HaveOccurred())
-			openaiclient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client))
-
-			params := openai.ChatCompletionNewParams{
-				Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage(userMessage)},
-				Model: model,
-				ToolChoice: openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")},
-				Tools: toolWithWrongMinMax,
-			}
+			openaiclient, params := getOpenAIClentAndChatParams(client, model, userMessage, false)
+			params.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")}
+			params.Tools = toolWithWrongMinMax
 			_, err = openaiclient.Chat.Completions.New(ctx, params)
 			Expect(err).To(HaveOccurred())
@@ -670,16 +633,9 @@ var _ = Describe("Simulator for request with tools", func() {
 			client, err := startServer(ctx, mode)
 			Expect(err).NotTo(HaveOccurred())
-			openaiclient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client))
-
-			params := openai.ChatCompletionNewParams{
-				Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage(userMessage)},
-				Model: model,
-				ToolChoice: openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")},
-				Tools: toolWithObjects,
-			}
+			openaiclient, params := getOpenAIClentAndChatParams(client, model, userMessage, false)
+			params.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")}
+			params.Tools = toolWithObjects
 			resp, err := openaiclient.Chat.Completions.New(ctx, params)
 			Expect(err).NotTo(HaveOccurred())
@@ -732,16 +688,9 @@ var _ = Describe("Simulator for request with tools", func() {
 			client, err := startServer(ctx, mode)
 			Expect(err).NotTo(HaveOccurred())
-			openaiclient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client))
-
-			params := openai.ChatCompletionNewParams{
-				Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage(userMessage)},
-				Model: model,
-				ToolChoice: openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")},
-				Tools: toolWithObjectAndArray,
-			}
+			openaiclient, params := getOpenAIClentAndChatParams(client, model, userMessage, false)
+			params.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")}
+			params.Tools = toolWithObjectAndArray
 			resp, err := openaiclient.Chat.Completions.New(ctx, params)
 			Expect(err).NotTo(HaveOccurred())
@@ -790,16 +739,9 @@ var _ = Describe("Simulator for request with tools", func() {
 			client, err := startServerWithArgs(ctx, common.ModeEcho, serverArgs, nil)
 			Expect(err).NotTo(HaveOccurred())
-			openaiclient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client))
-
-			params := openai.ChatCompletionNewParams{
-				Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage(userMessage)},
-				Model: model,
-				ToolChoice: openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")},
-				Tools: toolWithoutRequiredParams,
-			}
+			openaiclient, params := getOpenAIClentAndChatParams(client, model, userMessage, false)
+			params.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")}
+			params.Tools = toolWithoutRequiredParams
 			resp, err := openaiclient.Chat.Completions.New(ctx, params)
 			Expect(err).NotTo(HaveOccurred())
@@ -835,16 +777,9 @@ var _ = Describe("Simulator for request with tools", func() {
 			client, err := startServerWithArgs(ctx, common.ModeEcho, serverArgs, nil)
 			Expect(err).NotTo(HaveOccurred())
-			openaiclient := openai.NewClient(
-				option.WithBaseURL(baseURL),
-				option.WithHTTPClient(client))
-
-			params := openai.ChatCompletionNewParams{
-				Messages: []openai.ChatCompletionMessageParamUnion{openai.UserMessage(userMessage)},
-				Model: model,
-				ToolChoice: openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")},
-				Tools: toolWithObjectWithoutRequiredParams,
-			}
+			openaiclient, params := getOpenAIClentAndChatParams(client, model, userMessage, false)
+			params.ToolChoice = openai.ChatCompletionToolChoiceOptionUnionParam{OfAuto: param.NewOpt("required")}
+			params.Tools = toolWithObjectWithoutRequiredParams
 			resp, err := openaiclient.Chat.Completions.New(ctx, params)
 			Expect(err).NotTo(HaveOccurred())