84 changes: 12 additions & 72 deletions pkg/llm-d-inference-sim/failures_test.go
@@ -26,7 +26,6 @@ import (
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/openai/openai-go"
"github.com/openai/openai-go/option"

"github.com/llm-d/llm-d-inference-sim/pkg/common"
openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
@@ -135,18 +134,8 @@ var _ = Describe("Failures", func() {
})

It("should always return an error response for chat completions", func() {
openaiClient := openai.NewClient(
option.WithBaseURL(baseURL),
option.WithHTTPClient(client),
)

_, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
Model: model,
Messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage(userMessage),
},
})

openaiClient, params := getOpenAIClentAndChatParams(client, model, userMessage, false)
_, err := openaiClient.Chat.Completions.New(ctx, params)
Expect(err).To(HaveOccurred())

var openaiError *openai.Error
@@ -158,18 +147,8 @@ var _ = Describe("Failures", func() {
})

It("should always return an error response for text completions", func() {
openaiClient := openai.NewClient(
option.WithBaseURL(baseURL),
option.WithHTTPClient(client),
)

_, err := openaiClient.Completions.New(ctx, openai.CompletionNewParams{
Model: openai.CompletionNewParamsModel(model),
Prompt: openai.CompletionNewParamsPromptUnion{
OfString: openai.String(userMessage),
},
})

openaiClient, _ := getOpenAIClentAndChatParams(client, model, userMessage, false)
_, err := openaiClient.Completions.New(ctx, openai.CompletionNewParams{
Model: openai.CompletionNewParamsModel(model),
Prompt: openai.CompletionNewParamsPromptUnion{
OfString: openai.String(userMessage),
},
})
Expect(err).To(HaveOccurred())

var openaiError *openai.Error
@@ -194,18 +173,8 @@ var _ = Describe("Failures", func() {
})

It("should return only rate limit errors", func() {
openaiClient := openai.NewClient(
option.WithBaseURL(baseURL),
option.WithHTTPClient(client),
)

_, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
Model: model,
Messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage(userMessage),
},
})

openaiClient, params := getOpenAIClentAndChatParams(client, model, userMessage, false)
_, err := openaiClient.Chat.Completions.New(ctx, params)
Expect(err).To(HaveOccurred())

var openaiError *openai.Error
@@ -230,20 +199,11 @@ var _ = Describe("Failures", func() {
})

It("should return only specified error types", func() {
openaiClient := openai.NewClient(
option.WithBaseURL(baseURL),
option.WithHTTPClient(client),
)
openaiClient, params := getOpenAIClentAndChatParams(client, model, userMessage, false)

// Make multiple requests to verify we get the expected error types
for i := 0; i < 10; i++ {
_, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
Model: model,
Messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage(userMessage),
},
})

_, err := openaiClient.Chat.Completions.New(ctx, params)
Expect(err).To(HaveOccurred())

var openaiError *openai.Error
@@ -270,18 +230,8 @@ var _ = Describe("Failures", func() {
})

It("should never return errors and behave like random mode", func() {
openaiClient := openai.NewClient(
option.WithBaseURL(baseURL),
option.WithHTTPClient(client),
)

resp, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
Model: model,
Messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage(userMessage),
},
})

openaiClient, params := getOpenAIClentAndChatParams(client, model, userMessage, false)
resp, err := openaiClient.Chat.Completions.New(ctx, params)
Expect(err).ToNot(HaveOccurred())
Expect(resp.Choices).To(HaveLen(1))
Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty())
@@ -300,18 +250,8 @@ var _ = Describe("Failures", func() {
}, nil)
Expect(err).ToNot(HaveOccurred())

openaiClient := openai.NewClient(
option.WithBaseURL(baseURL),
option.WithHTTPClient(client),
)

_, err = openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{
Model: model,
Messages: []openai.ChatCompletionMessageParamUnion{
openai.UserMessage(userMessage),
},
})

openaiClient, params := getOpenAIClentAndChatParams(client, model, userMessage, false)
_, err = openaiClient.Chat.Completions.New(ctx, params)
Expect(err).To(HaveOccurred())

var openaiError *openai.Error
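The refactored tests rely on a shared helper whose definition is not part of the hunks shown above. Below is a minimal sketch of what it presumably looks like, reconstructed from the deleted lines; the exact signature, the parameter types, and the meaning of the final boolean (assumed here to be a streaming toggle) are assumptions rather than the PR's actual implementation. The sketch also assumes the test package's existing imports (openai, option, net/http) and the baseURL fixture.

func getOpenAIClentAndChatParams(client *http.Client, model, userMessage string, streaming bool) (openai.Client, openai.ChatCompletionNewParams) {
	// client that talks to the simulator through the test's HTTP client
	openaiClient := openai.NewClient(
		option.WithBaseURL(baseURL),
		option.WithHTTPClient(client),
	)
	// basic single-message chat request; in the real helper the streaming flag
	// (assumed name) would presumably toggle stream-specific options
	params := openai.ChatCompletionNewParams{
		Model: model,
		Messages: []openai.ChatCompletionMessageParamUnion{
			openai.UserMessage(userMessage),
		},
	}
	return openaiClient, params
}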
94 changes: 94 additions & 0 deletions pkg/llm-d-inference-sim/helpers.go
@@ -0,0 +1,94 @@
/*
Copyright 2025 The llm-d-inference-sim Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// Package llmdinferencesim implements the vLLM simulator.
package llmdinferencesim

import (
"encoding/json"
"fmt"
)

// isValidModel checks whether the given model is the base model or one of the loaded LoRAs
func (s *VllmSimulator) isValidModel(model string) bool {
for _, name := range s.config.ServedModelNames {
if model == name {
return true
}
}
for _, lora := range s.getLoras() {
if model == lora {
return true
}
}

return false
}

// isLora returns true if the given model name is one of the loaded LoRAs
func (s *VllmSimulator) isLora(model string) bool {
for _, lora := range s.getLoras() {
if model == lora {
return true
}
}

return false
}

// getDisplayedModelName returns the model name that must appear in API
// responses. LoRA adapters keep their explicit name, while all base-model
// requests are surfaced as the first alias from --served-model-name.
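//
// For illustration only (hypothetical names, not part of this PR): with
// --served-model-name "my-model" "alias-1" and a loaded LoRA adapter "my-lora",
// "my-lora" is returned as-is, while both "my-model" and "alias-1" are
// surfaced as "my-model".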
func (s *VllmSimulator) getDisplayedModelName(reqModel string) string {
if s.isLora(reqModel) {
return reqModel
}
return s.config.ServedModelNames[0]
}

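// showConfig logs the current configuration as indented JSON; when dp is true
// the port is omitted (dp presumably indicates data-parallel mode).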
func (s *VllmSimulator) showConfig(dp bool) error {
cfgJSON, err := json.Marshal(s.config)
if err != nil {
return fmt.Errorf("failed to marshal configuration to JSON: %w", err)
}

var m map[string]interface{}
err = json.Unmarshal(cfgJSON, &m)
if err != nil {
return fmt.Errorf("failed to unmarshal JSON to map: %w", err)
}
if dp {
// remove the port
delete(m, "port")
}
// expose LoraModules under the lora-modules key and drop the internal fields
m["lora-modules"] = m["LoraModules"]
delete(m, "LoraModules")
delete(m, "LoraModulesString")

// clean fake-metrics field
if field, ok := m["fake-metrics"].(map[string]interface{}); ok {
delete(field, "LorasString")
}

// show in JSON
cfgJSON, err = json.MarshalIndent(m, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal configuration to JSON: %w", err)
}
s.logger.Info("Configuration:", "", string(cfgJSON))
return nil
}
65 changes: 65 additions & 0 deletions pkg/llm-d-inference-sim/latencies.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
Copyright 2025 The llm-d-inference-sim Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

// Package llmdinferencesim implements the vLLM simulator.
package llmdinferencesim

import "github.com/llm-d/llm-d-inference-sim/pkg/common"

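// getCurrLoadFactor returns a latency multiplier that scales linearly from 1
// (a single running request) up to TimeFactorUnderLoad (MaxNumSeqs running requests).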
func (s *VllmSimulator) getCurrLoadFactor() float64 {
if s.config.MaxNumSeqs <= 1 {
return 1.0
}
return 1 + (s.config.TimeFactorUnderLoad-1)*float64(s.nRunningReqs-1)/float64(s.config.MaxNumSeqs-1)
}

func (s *VllmSimulator) getTimeToFirstToken() int {
return int(float64(s.config.TimeToFirstToken) * s.getCurrLoadFactor())
}

func (s *VllmSimulator) getPrefillOverhead() int {
return int(float64(s.config.PrefillOverhead) * s.getCurrLoadFactor())
}

func (s *VllmSimulator) getPrefillTimePerToken() int {
return int(float64(s.config.PrefillTimePerToken) * s.getCurrLoadFactor())
}

// getWaitTimeToFirstToken returns the time to first token, taking into account
// whether the current request uses remote prefill (disaggregated prefill/decode)
func (s *VllmSimulator) getWaitTimeToFirstToken(nPromptTokens int, nCachedPromptTokens int, doRemotePrefill bool) int {
if doRemotePrefill {
if s.config.KVCacheTransferLatency == 0 && s.config.KVCacheTransferLatencyStdDev == 0 {
// is disaggregated PD and ttft is calculated using number of prompt tokens
kvCacheTransT := s.config.KVCacheTransferTimePerToken * nPromptTokens
return common.RandomNorm(kvCacheTransT, s.config.KVCacheTransferTimeStdDev)
}
// is disaggregated PD and *not* using number of prompt tokens
return common.RandomNorm(s.config.KVCacheTransferLatency, s.config.KVCacheTransferLatencyStdDev)
}
if s.config.TimeToFirstToken == 0 && s.config.TimeToFirstTokenStdDev == 0 {
// is aggregated PD and ttft is calculated using number of prompt tokens that are not in kv cache
prefillTime := s.getPrefillOverhead() + (nPromptTokens-nCachedPromptTokens)*s.getPrefillTimePerToken()
return common.RandomNorm(prefillTime, s.config.PrefillTimeStdDev)
}
// is aggregated PD and *not* using number of prompt tokens
return common.RandomNorm(s.getTimeToFirstToken(), s.config.TimeToFirstTokenStdDev)
}

// getInterTokenLatency returns the inter-token latency, scaled by the current load
func (s *VllmSimulator) getInterTokenLatency() int {
latency := int(float64(s.config.InterTokenLatency) * s.getCurrLoadFactor())
return common.RandomNorm(latency, s.config.InterTokenLatencyStdDev)
}
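For intuition, a worked numerical example of the load scaling implemented above, written as a standalone sketch; the values are arbitrary and are not simulator defaults.

package main

import "fmt"

func main() {
	// hypothetical configuration values
	timeFactorUnderLoad := 3.0 // config.TimeFactorUnderLoad
	maxNumSeqs := 11           // config.MaxNumSeqs
	nRunningReqs := 6          // requests currently in flight
	timeToFirstToken := 100.0  // config.TimeToFirstToken (assumed to be in milliseconds)

	// the same linear interpolation as getCurrLoadFactor
	factor := 1 + (timeFactorUnderLoad-1)*float64(nRunningReqs-1)/float64(maxNumSeqs-1)
	fmt.Println(factor)                    // 2: halfway to full load doubles latencies
	fmt.Println(timeToFirstToken * factor) // 200: scaled TTFT before common.RandomNorm adds jitter
}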