Commit 4ae89f2

Merge branch 'main' into dev/prefill-overhead
Signed-off-by: Qifan Deng <[email protected]>
2 parents: 9886b94 + b98882a

File tree

13 files changed: +1403 -310 lines

.github/workflows/re-run-action.yml

Lines changed: 0 additions & 16 deletions
This file was deleted.

README.md

Lines changed: 3 additions & 1 deletion
@@ -135,6 +135,9 @@ For more details see the <a href="https://docs.vllm.ai/en/stable/getting_started
 - `zmq-max-connect-attempts`: the maximum number of ZMQ connection attempts, defaults to 0, maximum: 10
 - `event-batch-size`: the maximum number of kv-cache events to be sent together, defaults to 16
 ---
+- `failure-injection-rate`: probability (0-100) of injecting failures, optional, default is 0
+- `failure-types`: list of specific failure types to inject (rate_limit, invalid_api_key, context_length, server_error, invalid_request, model_not_found), optional, if empty all types are used
+---
 - `fake-metrics`: represents a predefined set of metrics to be sent to Prometheus as a substitute for the real metrics. When specified, only these fake metrics will be reported — real metrics and fake metrics will never be reported together. The set should include values for
 - `running-requests`
 - `waiting-requests`
@@ -143,7 +146,6 @@ For more details see the <a href="https://docs.vllm.ai/en/stable/getting_started
 
 Example:
 {"running-requests":10,"waiting-requests":30,"kv-cache-usage":0.4,"loras":[{"running":"lora4,lora2","waiting":"lora3","timestamp":1257894567},{"running":"lora4,lora3","waiting":"","timestamp":1257894569}]}
-
 
 In addition, as we are using klog, the following parameters are available:
 - `add_dir_header`: if true, adds the file directory to the header of the log messages
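
To make the semantics of `failure-injection-rate` concrete: a value of 30 means roughly 30% of requests receive an injected failure instead of a normal response. Below is a minimal standalone sketch of such a percentage gate (illustration only; the commit's actual check, shown in the new failures file further down, is `common.RandomInt(1, 100) <= rate`):

```go
package main

import (
	"fmt"
	"math/rand"
)

// shouldFail returns true with probability rate/100.
// rng.Intn(100) is uniform over 0..99, so "< rate" succeeds rate% of the time.
func shouldFail(rate int, rng *rand.Rand) bool {
	if rate <= 0 {
		return false
	}
	return rng.Intn(100) < rate
}

func main() {
	rng := rand.New(rand.NewSource(42))
	hits := 0
	for i := 0; i < 10000; i++ {
		if shouldFail(30, rng) {
			hits++
		}
	}
	fmt.Printf("injected %d of 10000 requests (expected ~3000)\n", hits)
}
```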

pkg/common/config.go

Lines changed: 45 additions & 2 deletions
@@ -34,7 +34,14 @@ const (
 	vLLMDefaultPort = 8000
 	ModeRandom      = "random"
 	ModeEcho        = "echo"
-	dummy = "dummy"
+	// Failure type constants
+	FailureTypeRateLimit      = "rate_limit"
+	FailureTypeInvalidAPIKey  = "invalid_api_key"
+	FailureTypeContextLength  = "context_length"
+	FailureTypeServerError    = "server_error"
+	FailureTypeInvalidRequest = "invalid_request"
+	FailureTypeModelNotFound  = "model_not_found"
+	dummy = "dummy"
 )
 
 type Configuration struct {
@@ -150,6 +157,11 @@ type Configuration struct {
 
 	// FakeMetrics is a set of metrics to send to Prometheus instead of the real data
 	FakeMetrics *Metrics `yaml:"fake-metrics" json:"fake-metrics"`
+
+	// FailureInjectionRate is the probability (0-100) of injecting failures
+	FailureInjectionRate int `yaml:"failure-injection-rate" json:"failure-injection-rate"`
+	// FailureTypes is a list of specific failure types to inject (empty means all types)
+	FailureTypes []string `yaml:"failure-types" json:"failure-types"`
 }
 
 type Metrics struct {
@@ -392,6 +404,27 @@ func (c *Configuration) validate() error {
 	if c.EventBatchSize < 1 {
 		return errors.New("event batch size cannot less than 1")
 	}
+
+	if c.FailureInjectionRate < 0 || c.FailureInjectionRate > 100 {
+		return errors.New("failure injection rate should be between 0 and 100")
+	}
+
+	validFailureTypes := map[string]bool{
+		FailureTypeRateLimit:      true,
+		FailureTypeInvalidAPIKey:  true,
+		FailureTypeContextLength:  true,
+		FailureTypeServerError:    true,
+		FailureTypeInvalidRequest: true,
+		FailureTypeModelNotFound:  true,
+	}
+	for _, failureType := range c.FailureTypes {
+		if !validFailureTypes[failureType] {
+			return fmt.Errorf("invalid failure type '%s', valid types are: %s, %s, %s, %s, %s, %s", failureType,
+				FailureTypeRateLimit, FailureTypeInvalidAPIKey, FailureTypeContextLength,
+				FailureTypeServerError, FailureTypeInvalidRequest, FailureTypeModelNotFound)
+		}
+	}
+
 	if c.ZMQMaxConnectAttempts > 10 {
 		return errors.New("zmq retries times cannot be more than 10")
 	}
@@ -432,7 +465,7 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) {
 	f.IntVar(&config.MaxCPULoras, "max-cpu-loras", config.MaxCPULoras, "Maximum number of LoRAs to store in CPU memory")
 	f.IntVar(&config.MaxModelLen, "max-model-len", config.MaxModelLen, "Model's context window, maximum number of tokens in a single request including input and output")
 
-	f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode, echo - returns the same text that was sent in the request, for chat completion returns the last message, random - returns random sentence from a bank of pre-defined sentences")
+	f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode: echo - returns the same text that was sent in the request, for chat completion returns the last message; random - returns random sentence from a bank of pre-defined sentences")
 	f.IntVar(&config.InterTokenLatency, "inter-token-latency", config.InterTokenLatency, "Time to generate one token (in milliseconds)")
 	f.IntVar(&config.TimeToFirstToken, "time-to-first-token", config.TimeToFirstToken, "Time to first token (in milliseconds)")
 
@@ -466,6 +499,13 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) {
 	f.UintVar(&config.ZMQMaxConnectAttempts, "zmq-max-connect-attempts", config.ZMQMaxConnectAttempts, "Maximum number of times to try ZMQ connect")
 	f.IntVar(&config.EventBatchSize, "event-batch-size", config.EventBatchSize, "Maximum number of kv-cache events to be sent together")
 
+	f.IntVar(&config.FailureInjectionRate, "failure-injection-rate", config.FailureInjectionRate, "Probability (0-100) of injecting failures")
+
+	failureTypes := getParamValueFromArgs("failure-types")
+	var dummyFailureTypes multiString
+	f.Var(&dummyFailureTypes, "failure-types", "List of specific failure types to inject (rate_limit, invalid_api_key, context_length, server_error, invalid_request, model_not_found)")
+	f.Lookup("failure-types").NoOptDefVal = dummy
+
 	// These values were manually parsed above in getParamValueFromArgs, we leave this in order to get these flags in --help
 	var dummyString string
 	f.StringVar(&dummyString, "config", "", "The path to a yaml configuration file. The command line values overwrite the configuration file values")
@@ -505,6 +545,9 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) {
 	if servedModelNames != nil {
 		config.ServedModelNames = servedModelNames
 	}
+	if failureTypes != nil {
+		config.FailureTypes = failureTypes
+	}
 
 	if config.HashSeed == "" {
 		hashSeed := os.Getenv("PYTHONHASHSEED")
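
The new validation accepts any subset of the six predefined type names and rejects everything else. A self-contained sketch of the same check (hypothetical helper name; the real logic is the unexported `Configuration.validate()` shown above):

```go
package main

import "fmt"

// validateFailureTypes mirrors the commit's check: every configured type
// must be one of the six predefined failure type names.
func validateFailureTypes(types []string) error {
	valid := map[string]bool{
		"rate_limit": true, "invalid_api_key": true, "context_length": true,
		"server_error": true, "invalid_request": true, "model_not_found": true,
	}
	for _, t := range types {
		if !valid[t] {
			return fmt.Errorf("invalid failure type '%s'", t)
		}
	}
	return nil
}

func main() {
	fmt.Println(validateFailureTypes([]string{"rate_limit", "server_error"})) // <nil>
	fmt.Println(validateFailureTypes([]string{"invalid_type"}))               // error
}
```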

pkg/common/config_test.go

Lines changed: 13 additions & 0 deletions
@@ -370,6 +370,19 @@ var _ = Describe("Simulator configuration", func() {
 			args: []string{"cmd", "--event-batch-size", "-35",
 				"--config", "../../manifests/config.yaml"},
 		},
+		{
+			name: "invalid failure injection rate > 100",
+			args: []string{"cmd", "--model", "test-model", "--failure-injection-rate", "150"},
+		},
+		{
+			name: "invalid failure injection rate < 0",
+			args: []string{"cmd", "--model", "test-model", "--failure-injection-rate", "-10"},
+		},
+		{
+			name: "invalid failure type",
+			args: []string{"cmd", "--model", "test-model", "--failure-injection-rate", "50",
+				"--failure-types", "invalid_type"},
+		},
 		{
 			name: "invalid fake metrics: negative running requests",
 			args: []string{"cmd", "--fake-metrics", "{\"running-requests\":-10,\"waiting-requests\":30,\"kv-cache-usage\":0.4}",

pkg/common/utils.go

Lines changed: 69 additions & 2 deletions
@@ -39,6 +39,10 @@ const (
 	RemoteDecodeFinishReason = "remote_decode"
 )
 
+// this array defines the probabilities for the buckets to be used for the generation of number of tokens in response
+var respLenBucketsProbabilities = [...]float64{0.2, 0.3, 0.2, 0.05, 0.1, 0.15}
+var cumulativeBucketsProbabilities []float64
+
 // list of responses to use in random mode for comepltion requests
 var chatCompletionFakeResponses = []string{
 	`Testing@, #testing 1$ ,2%,3^, [4&*5], 6~, 7-_ + (8 : 9) / \ < > .`,
@@ -54,6 +58,16 @@ var chatCompletionFakeResponses = []string{
 	`Give a man a fish and you feed him for a day; teach a man to fish and you feed him for a lifetime`,
 }
 
+func init() {
+	cumulativeBucketsProbabilities = make([]float64, len(respLenBucketsProbabilities))
+	sum := 0.0
+
+	for i, val := range respLenBucketsProbabilities {
+		sum += val
+		cumulativeBucketsProbabilities[i] = sum
+	}
+}
+
 // returns the max tokens or error if incorrect
 func GetMaxTokens(maxCompletionTokens *int64, maxTokens *int64) (*int64, error) {
 	var typeToken string
@@ -154,14 +168,67 @@ func GetRandomResponseText(maxCompletionTokens *int64) (string, string) {
 	if maxCompletionTokens == nil {
 		numOfTokens = GetRandomResponseLen()
 	} else {
-		numOfTokens = int(*maxCompletionTokens)
-		finishReason = GetRandomFinishReason()
+		maxTokens := int(*maxCompletionTokens)
+		// max tokens is defined - generate real length of the response based on it
+		numOfTokens = getResponseLengthByHistogram(maxTokens)
+		if numOfTokens == maxTokens {
+			// if response should be create with maximum number of tokens - finish reason will be 'length'
+			finishReason = LengthFinishReason
+		}
 	}
 
 	text := GetRandomText(numOfTokens)
 	return text, finishReason
 }
 
+// getResponseLengthByHistogram calculates the number of tokens to be returned in a response based on the max tokens value and the pre-defined buckets.
+// The response length is distributed according to the probabilities, defined in respLenBucketsProbabilities.
+// The histogram contains equally sized buckets and the last special bucket, which contains only the maxTokens value.
+// The last element of respLenBucketsProbabilities defines the probability of a reposnse with maxToken tokens.
+// Other values define probabilities for the equally sized buckets.
+// If maxToken is small (smaller than number of buckets) - the response length is randomly selected from the range [1, maxTokens]
+func getResponseLengthByHistogram(maxTokens int) int {
+	if maxTokens <= 1 {
+		return maxTokens
+	}
+	// maxTokens is small - no need to use the histogram of probabilities, just select a random value in the range [1, maxTokens]
+	if maxTokens <= len(cumulativeBucketsProbabilities) {
+		res := RandomInt(1, maxTokens)
+		return res
+	}
+
+	r := RandomFloat(0, 1)
+
+	// check if r is in the last bucket, then maxTokens should be returned
+	if r > cumulativeBucketsProbabilities[len(cumulativeBucketsProbabilities)-2] {
+		return maxTokens
+	}
+
+	// determine which bucket to use, the bucket with a cumulative probability larger than r is the bucket to use
+	// initialize bucketIndex with the last bucket to handle the case (which should not happen) when the probabilities sum is less than 1
+	bucketIndex := len(cumulativeBucketsProbabilities) - 1
+	for i, c := range cumulativeBucketsProbabilities {
+		if r <= c {
+			bucketIndex = i
+			break
+		}
+	}
+
+	// calculate the size of all of the buckets (except the special last bucket)
+	bucketSize := float64(maxTokens-1) / float64(len(cumulativeBucketsProbabilities)-1)
+	// start is the minimum number in the required bucket
+	start := int(bucketSize*float64(bucketIndex)) + 1
+	// end is the maximum number in the required bucket
+	end := int(bucketSize * float64(bucketIndex+1))
+	// sometimes end could be maxTokens because of rounding, change the value to maxToken-1
+	if end >= maxTokens {
+		end = maxTokens - 1
+	}
+
+	// pick uniformly within the bucket's range
+	return RandomInt(start, end)
+}
+
 // GetResponseText returns response text, from a given text
 // considering max completion tokens if it is not nil, and a finish reason (stop or length)
 func GetResponseText(maxCompletionTokens *int64, text string) (string, string) {
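
A worked example of the histogram, assuming `maxTokens = 100` and the probabilities above: there are five equal-width buckets over [1, 99] plus the special last bucket holding exactly 100 (probability 0.15). The following standalone sketch reproduces the boundary arithmetic from `getResponseLengthByHistogram`:

```go
package main

import "fmt"

// Bucket probabilities from the commit: the last entry is the chance of
// returning exactly maxTokens; the rest cover equal-width ranges below it.
var probs = []float64{0.2, 0.3, 0.2, 0.05, 0.1, 0.15}

func main() {
	maxTokens := 100
	n := len(probs) - 1 // number of equal-width buckets
	bucketSize := float64(maxTokens-1) / float64(n)
	for i := 0; i < n; i++ {
		start := int(bucketSize*float64(i)) + 1
		end := int(bucketSize * float64(i+1))
		if end >= maxTokens { // rounding guard, as in the commit
			end = maxTokens - 1
		}
		fmt.Printf("bucket %d: [%d, %d], probability %.2f\n", i, start, end, probs[i])
	}
	fmt.Printf("bucket %d: exactly %d, probability %.2f\n", n, maxTokens, probs[n])
}
```

For `maxTokens = 100` this prints [1, 19], [20, 39], [40, 59], [60, 79], [80, 99], and "exactly 100": for example, 30% of responses fall in the 20-39 token range, and 15% hit the maximum and therefore finish with reason 'length'.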

pkg/common/utils_test.go

Lines changed: 16 additions & 4 deletions
@@ -38,16 +38,28 @@ var _ = Describe("Utils", Ordered, func() {
 		It("should return short text", func() {
 			maxCompletionTokens := int64(2)
 			text, finishReason := GetRandomResponseText(&maxCompletionTokens)
-			Expect(int64(len(Tokenize(text)))).Should(Equal(maxCompletionTokens))
-			Expect([]string{StopFinishReason, LengthFinishReason}).Should(ContainElement(finishReason))
+			tokensCnt := int64(len(Tokenize(text)))
+			Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
+			if tokensCnt == maxCompletionTokens {
+				Expect(finishReason).To(Equal(LengthFinishReason))
+			} else {
+				Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens))
+				Expect(finishReason).To(Equal(StopFinishReason))
+			}
 		})
 		It("should return long text", func() {
 			// return required number of tokens although it is higher than ResponseLenMax
 			maxCompletionTokens := int64(ResponseLenMax * 5)
 			text, finishReason := GetRandomResponseText(&maxCompletionTokens)
-			Expect(int64(len(Tokenize(text)))).Should(Equal(maxCompletionTokens))
+			tokensCnt := int64(len(Tokenize(text)))
+			Expect(tokensCnt).Should(BeNumerically("<=", maxCompletionTokens))
 			Expect(IsValidText(text)).To(BeTrue())
-			Expect([]string{StopFinishReason, LengthFinishReason}).Should(ContainElement(finishReason))
+			if tokensCnt == maxCompletionTokens {
+				Expect(finishReason).To(Equal(LengthFinishReason))
+			} else {
+				Expect(tokensCnt).To(BeNumerically("<", maxCompletionTokens))
+				Expect(finishReason).To(Equal(StopFinishReason))
+			}
 		})
 	})
 
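
The updated assertions encode the contract that follows from the histogram change: the generated response may now be shorter than the requested maximum, and the finish reason must agree with the length. A standalone restatement of that invariant (hypothetical helper; assumes the usual finish-reason strings "stop" and "length"):

```go
package main

import "fmt"

// responseInvariant restates what the updated tests assert: token count
// never exceeds the maximum, finish reason is "length" exactly when the
// maximum is reached, and "stop" otherwise.
func responseInvariant(tokens, maxTokens int64, finishReason string) bool {
	switch {
	case tokens > maxTokens:
		return false
	case tokens == maxTokens:
		return finishReason == "length"
	default:
		return finishReason == "stop"
	}
}

func main() {
	fmt.Println(responseInvariant(2, 2, "length")) // true
	fmt.Println(responseInvariant(1, 2, "stop"))   // true
	fmt.Println(responseInvariant(3, 2, "stop"))   // false
}
```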

Lines changed: 88 additions & 0 deletions (new file)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
Copyright 2025 The llm-d-inference-sim Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package llmdinferencesim
18+
19+
import (
20+
"fmt"
21+
22+
"github.com/llm-d/llm-d-inference-sim/pkg/common"
23+
openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
24+
)
25+
26+
const (
27+
// Error message templates
28+
rateLimitMessageTemplate = "Rate limit reached for %s in organization org-xxx on requests per min (RPM): Limit 3, Used 3, Requested 1."
29+
modelNotFoundMessageTemplate = "The model '%s-nonexistent' does not exist"
30+
)
31+
32+
var predefinedFailures = map[string]openaiserverapi.CompletionError{
33+
common.FailureTypeRateLimit: openaiserverapi.NewCompletionError(rateLimitMessageTemplate, 429, nil),
34+
common.FailureTypeInvalidAPIKey: openaiserverapi.NewCompletionError("Incorrect API key provided.", 401, nil),
35+
common.FailureTypeContextLength: openaiserverapi.NewCompletionError(
36+
"This model's maximum context length is 4096 tokens. However, your messages resulted in 4500 tokens.",
37+
400, stringPtr("messages")),
38+
common.FailureTypeServerError: openaiserverapi.NewCompletionError(
39+
"The server is overloaded or not ready yet.", 503, nil),
40+
common.FailureTypeInvalidRequest: openaiserverapi.NewCompletionError(
41+
"Invalid request: missing required parameter 'model'.", 400, stringPtr("model")),
42+
common.FailureTypeModelNotFound: openaiserverapi.NewCompletionError(modelNotFoundMessageTemplate,
43+
404, stringPtr("model")),
44+
}
45+
46+
// shouldInjectFailure determines whether to inject a failure based on configuration
47+
func shouldInjectFailure(config *common.Configuration) bool {
48+
if config.FailureInjectionRate == 0 {
49+
return false
50+
}
51+
52+
return common.RandomInt(1, 100) <= config.FailureInjectionRate
53+
}
54+
55+
// getRandomFailure returns a random failure from configured types or all types if none specified
56+
func getRandomFailure(config *common.Configuration) openaiserverapi.CompletionError {
57+
var availableFailures []string
58+
if len(config.FailureTypes) == 0 {
59+
// Use all failure types if none specified
60+
for failureType := range predefinedFailures {
61+
availableFailures = append(availableFailures, failureType)
62+
}
63+
} else {
64+
availableFailures = config.FailureTypes
65+
}
66+
67+
if len(availableFailures) == 0 {
68+
// Fallback to server_error if no valid types
69+
return predefinedFailures[common.FailureTypeServerError]
70+
}
71+
72+
randomIndex := common.RandomInt(0, len(availableFailures)-1)
73+
randomType := availableFailures[randomIndex]
74+
75+
// Customize message with current model name
76+
failure := predefinedFailures[randomType]
77+
if randomType == common.FailureTypeRateLimit && config.Model != "" {
78+
failure.Message = fmt.Sprintf(rateLimitMessageTemplate, config.Model)
79+
} else if randomType == common.FailureTypeModelNotFound && config.Model != "" {
80+
failure.Message = fmt.Sprintf(modelNotFoundMessageTemplate, config.Model)
81+
}
82+
83+
return failure
84+
}
85+
86+
func stringPtr(s string) *string {
87+
return &s
88+
}
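
The selection strategy in `getRandomFailure` is straightforward: restrict the pool to the configured types when any are given, otherwise draw uniformly from all six predefined failures. A self-contained sketch of that pool logic (hypothetical names, plain `math/rand` in place of the repo's `common.RandomInt`):

```go
package main

import (
	"fmt"
	"math/rand"
)

// pickFailureType restricts the draw to the configured types if any are
// set; otherwise it draws uniformly from all predefined types.
func pickFailureType(configured, all []string, rng *rand.Rand) string {
	pool := configured
	if len(pool) == 0 {
		pool = all
	}
	return pool[rng.Intn(len(pool))]
}

func main() {
	all := []string{"rate_limit", "invalid_api_key", "context_length",
		"server_error", "invalid_request", "model_not_found"}
	rng := rand.New(rand.NewSource(1))
	fmt.Println(pickFailureType(nil, all, rng))                                    // any of the six
	fmt.Println(pickFailureType([]string{"rate_limit", "server_error"}, all, rng)) // one of the two
}
```

Note that `shouldInjectFailure` treats a rate of 0 as "never", while `common.RandomInt(1, 100) <= rate` makes a rate of 100 fire on every request.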
