diff --git a/README.md b/README.md index f274b2b1..8b1bd80c 100644 --- a/README.md +++ b/README.md @@ -124,6 +124,8 @@ For more details see the 100 { + return errors.New("failure injection rate should be between 0 and 100") + } + + validFailureTypes := map[string]bool{ + FailureTypeRateLimit: true, + FailureTypeInvalidAPIKey: true, + FailureTypeContextLength: true, + FailureTypeServerError: true, + FailureTypeInvalidRequest: true, + FailureTypeModelNotFound: true, + } + for _, failureType := range c.FailureTypes { + if !validFailureTypes[failureType] { + return fmt.Errorf("invalid failure type '%s', valid types are: %s, %s, %s, %s, %s, %s", failureType, + FailureTypeRateLimit, FailureTypeInvalidAPIKey, FailureTypeContextLength, + FailureTypeServerError, FailureTypeInvalidRequest, FailureTypeModelNotFound) + } + } + if c.ZMQMaxConnectAttempts > 10 { return errors.New("zmq retries times cannot be more than 10") } @@ -397,7 +430,7 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.IntVar(&config.MaxCPULoras, "max-cpu-loras", config.MaxCPULoras, "Maximum number of LoRAs to store in CPU memory") f.IntVar(&config.MaxModelLen, "max-model-len", config.MaxModelLen, "Model's context window, maximum number of tokens in a single request including input and output") - f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode, echo - returns the same text that was sent in the request, for chat completion returns the last message, random - returns random sentence from a bank of pre-defined sentences") + f.StringVar(&config.Mode, "mode", config.Mode, "Simulator mode: echo - returns the same text that was sent in the request, for chat completion returns the last message; random - returns random sentence from a bank of pre-defined sentences") f.IntVar(&config.InterTokenLatency, "inter-token-latency", config.InterTokenLatency, "Time to generate one token (in milliseconds)") f.IntVar(&config.TimeToFirstToken, "time-to-first-token", config.TimeToFirstToken, "Time to first token (in milliseconds)") f.IntVar(&config.KVCacheTransferLatency, "kv-cache-transfer-latency", config.KVCacheTransferLatency, "Time for KV-cache transfer from a remote vLLM (in milliseconds)") @@ -424,6 +457,13 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { f.UintVar(&config.ZMQMaxConnectAttempts, "zmq-max-connect-attempts", config.ZMQMaxConnectAttempts, "Maximum number of times to try ZMQ connect") f.IntVar(&config.EventBatchSize, "event-batch-size", config.EventBatchSize, "Maximum number of kv-cache events to be sent together") + f.IntVar(&config.FailureInjectionRate, "failure-injection-rate", config.FailureInjectionRate, "Probability (0-100) of injecting failures") + + failureTypes := getParamValueFromArgs("failure-types") + var dummyFailureTypes multiString + f.Var(&dummyFailureTypes, "failure-types", "List of specific failure types to inject (rate_limit, invalid_api_key, context_length, server_error, invalid_request, model_not_found)") + f.Lookup("failure-types").NoOptDefVal = dummy + // These values were manually parsed above in getParamValueFromArgs, we leave this in order to get these flags in --help var dummyString string f.StringVar(&dummyString, "config", "", "The path to a yaml configuration file. 
The command line values overwrite the configuration file values") @@ -463,6 +503,9 @@ func ParseCommandParamsAndLoadConfig() (*Configuration, error) { if servedModelNames != nil { config.ServedModelNames = servedModelNames } + if failureTypes != nil { + config.FailureTypes = failureTypes + } if config.HashSeed == "" { hashSeed := os.Getenv("PYTHONHASHSEED") diff --git a/pkg/common/config_test.go b/pkg/common/config_test.go index f50c40a9..770716a6 100644 --- a/pkg/common/config_test.go +++ b/pkg/common/config_test.go @@ -370,6 +370,19 @@ var _ = Describe("Simulator configuration", func() { args: []string{"cmd", "--event-batch-size", "-35", "--config", "../../manifests/config.yaml"}, }, + { + name: "invalid failure injection rate > 100", + args: []string{"cmd", "--model", "test-model", "--failure-injection-rate", "150"}, + }, + { + name: "invalid failure injection rate < 0", + args: []string{"cmd", "--model", "test-model", "--failure-injection-rate", "-10"}, + }, + { + name: "invalid failure type", + args: []string{"cmd", "--model", "test-model", "--failure-injection-rate", "50", + "--failure-types", "invalid_type"}, + }, { name: "invalid fake metrics: negative running requests", args: []string{"cmd", "--fake-metrics", "{\"running-requests\":-10,\"waiting-requests\":30,\"kv-cache-usage\":0.4}", diff --git a/pkg/llm-d-inference-sim/failures.go b/pkg/llm-d-inference-sim/failures.go new file mode 100644 index 00000000..69daf36e --- /dev/null +++ b/pkg/llm-d-inference-sim/failures.go @@ -0,0 +1,88 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llmdinferencesim + +import ( + "fmt" + + "github.com/llm-d/llm-d-inference-sim/pkg/common" + openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" +) + +const ( + // Error message templates + rateLimitMessageTemplate = "Rate limit reached for %s in organization org-xxx on requests per min (RPM): Limit 3, Used 3, Requested 1." + modelNotFoundMessageTemplate = "The model '%s-nonexistent' does not exist" +) + +var predefinedFailures = map[string]openaiserverapi.CompletionError{ + common.FailureTypeRateLimit: openaiserverapi.NewCompletionError(rateLimitMessageTemplate, 429, nil), + common.FailureTypeInvalidAPIKey: openaiserverapi.NewCompletionError("Incorrect API key provided.", 401, nil), + common.FailureTypeContextLength: openaiserverapi.NewCompletionError( + "This model's maximum context length is 4096 tokens. 
However, your messages resulted in 4500 tokens.", + 400, stringPtr("messages")), + common.FailureTypeServerError: openaiserverapi.NewCompletionError( + "The server is overloaded or not ready yet.", 503, nil), + common.FailureTypeInvalidRequest: openaiserverapi.NewCompletionError( + "Invalid request: missing required parameter 'model'.", 400, stringPtr("model")), + common.FailureTypeModelNotFound: openaiserverapi.NewCompletionError(modelNotFoundMessageTemplate, + 404, stringPtr("model")), +} + +// shouldInjectFailure determines whether to inject a failure based on configuration +func shouldInjectFailure(config *common.Configuration) bool { + if config.FailureInjectionRate == 0 { + return false + } + + return common.RandomInt(1, 100) <= config.FailureInjectionRate +} + +// getRandomFailure returns a random failure from configured types or all types if none specified +func getRandomFailure(config *common.Configuration) openaiserverapi.CompletionError { + var availableFailures []string + if len(config.FailureTypes) == 0 { + // Use all failure types if none specified + for failureType := range predefinedFailures { + availableFailures = append(availableFailures, failureType) + } + } else { + availableFailures = config.FailureTypes + } + + if len(availableFailures) == 0 { + // Fallback to server_error if no valid types + return predefinedFailures[common.FailureTypeServerError] + } + + randomIndex := common.RandomInt(0, len(availableFailures)-1) + randomType := availableFailures[randomIndex] + + // Customize message with current model name + failure := predefinedFailures[randomType] + if randomType == common.FailureTypeRateLimit && config.Model != "" { + failure.Message = fmt.Sprintf(rateLimitMessageTemplate, config.Model) + } else if randomType == common.FailureTypeModelNotFound && config.Model != "" { + failure.Message = fmt.Sprintf(modelNotFoundMessageTemplate, config.Model) + } + + return failure +} + +func stringPtr(s string) *string { + return &s +} diff --git a/pkg/llm-d-inference-sim/failures_test.go b/pkg/llm-d-inference-sim/failures_test.go new file mode 100644 index 00000000..5ff48034 --- /dev/null +++ b/pkg/llm-d-inference-sim/failures_test.go @@ -0,0 +1,334 @@ +/* +Copyright 2025 The llm-d-inference-sim Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package llmdinferencesim + +import ( + "context" + "errors" + "net/http" + "strings" + "time" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + "github.com/openai/openai-go" + "github.com/openai/openai-go/option" + + "github.com/llm-d/llm-d-inference-sim/pkg/common" + openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" +) + +var _ = Describe("Failures", func() { + Describe("getRandomFailure", Ordered, func() { + BeforeAll(func() { + common.InitRandom(time.Now().UnixNano()) + }) + + It("should return a failure from all types when none specified", func() { + config := &common.Configuration{ + Model: "test-model", + FailureTypes: []string{}, + } + failure := getRandomFailure(config) + Expect(failure.Code).To(BeNumerically(">=", 400)) + Expect(failure.Message).ToNot(BeEmpty()) + Expect(failure.Type).ToNot(BeEmpty()) + }) + + It("should return rate limit failure when specified", func() { + config := &common.Configuration{ + Model: "test-model", + FailureTypes: []string{common.FailureTypeRateLimit}, + } + failure := getRandomFailure(config) + Expect(failure.Code).To(Equal(429)) + Expect(failure.Type).To(Equal(openaiserverapi.ErrorCodeToType(429))) + Expect(strings.Contains(failure.Message, "test-model")).To(BeTrue()) + }) + + It("should return invalid API key failure when specified", func() { + config := &common.Configuration{ + FailureTypes: []string{common.FailureTypeInvalidAPIKey}, + } + failure := getRandomFailure(config) + Expect(failure.Code).To(Equal(401)) + Expect(failure.Type).To(Equal(openaiserverapi.ErrorCodeToType(401))) + Expect(failure.Message).To(Equal("Incorrect API key provided.")) + }) + + It("should return context length failure when specified", func() { + config := &common.Configuration{ + FailureTypes: []string{common.FailureTypeContextLength}, + } + failure := getRandomFailure(config) + Expect(failure.Code).To(Equal(400)) + Expect(failure.Type).To(Equal(openaiserverapi.ErrorCodeToType(400))) + Expect(failure.Param).ToNot(BeNil()) + Expect(*failure.Param).To(Equal("messages")) + }) + + It("should return server error when specified", func() { + config := &common.Configuration{ + FailureTypes: []string{common.FailureTypeServerError}, + } + failure := getRandomFailure(config) + Expect(failure.Code).To(Equal(503)) + Expect(failure.Type).To(Equal(openaiserverapi.ErrorCodeToType(503))) + }) + + It("should return model not found failure when specified", func() { + config := &common.Configuration{ + Model: "test-model", + FailureTypes: []string{common.FailureTypeModelNotFound}, + } + failure := getRandomFailure(config) + Expect(failure.Code).To(Equal(404)) + Expect(failure.Type).To(Equal(openaiserverapi.ErrorCodeToType(404))) + Expect(strings.Contains(failure.Message, "test-model-nonexistent")).To(BeTrue()) + }) + + It("should return server error as fallback for empty types", func() { + config := &common.Configuration{ + FailureTypes: []string{}, + } + // This test is probabilistic since it randomly selects, but we can test structure + failure := getRandomFailure(config) + Expect(failure.Code).To(BeNumerically(">=", 400)) + Expect(failure.Type).ToNot(BeEmpty()) + }) + }) + Describe("Simulator with failure injection", func() { + var ( + client *http.Client + ctx context.Context + ) + + AfterEach(func() { + if ctx != nil { + ctx.Done() + } + }) + + Context("with 100% failure injection rate", func() { + BeforeEach(func() { + ctx = context.Background() + var err error + client, err = startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--failure-injection-rate", "100", + }, nil) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should always return an 
error response for chat completions", func() { + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + _, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + }) + + Expect(err).To(HaveOccurred()) + + var openaiError *openai.Error + ok := errors.As(err, &openaiError) + Expect(ok).To(BeTrue()) + Expect(openaiError.StatusCode).To(BeNumerically(">=", 400)) + Expect(openaiError.Type).ToNot(BeEmpty()) + Expect(openaiError.Message).ToNot(BeEmpty()) + }) + + It("should always return an error response for text completions", func() { + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + _, err := openaiClient.Completions.New(ctx, openai.CompletionNewParams{ + Model: openai.CompletionNewParamsModel(model), + Prompt: openai.CompletionNewParamsPromptUnion{ + OfString: openai.String(userMessage), + }, + }) + + Expect(err).To(HaveOccurred()) + + var openaiError *openai.Error + ok := errors.As(err, &openaiError) + Expect(ok).To(BeTrue()) + Expect(openaiError.StatusCode).To(BeNumerically(">=", 400)) + Expect(openaiError.Type).ToNot(BeEmpty()) + Expect(openaiError.Message).ToNot(BeEmpty()) + }) + }) + + Context("with specific failure types", func() { + BeforeEach(func() { + ctx = context.Background() + var err error + client, err = startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--failure-injection-rate", "100", + "--failure-types", common.FailureTypeRateLimit, + }, nil) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should return only rate limit errors", func() { + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + _, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + }) + + Expect(err).To(HaveOccurred()) + + var openaiError *openai.Error + ok := errors.As(err, &openaiError) + Expect(ok).To(BeTrue()) + Expect(openaiError.StatusCode).To(Equal(429)) + Expect(openaiError.Type).To(Equal(openaiserverapi.ErrorCodeToType(429))) + Expect(strings.Contains(openaiError.Message, model)).To(BeTrue()) + }) + }) + + Context("with multiple specific failure types", func() { + BeforeEach(func() { + ctx = context.Background() + var err error + client, err = startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--failure-injection-rate", "100", + "--failure-types", common.FailureTypeInvalidAPIKey, common.FailureTypeServerError, + }, nil) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should return only specified error types", func() { + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + // Make multiple requests to verify we get the expected error types + for i := 0; i < 10; i++ { + _, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + }) + + Expect(err).To(HaveOccurred()) + + var openaiError *openai.Error + ok := errors.As(err, &openaiError) + Expect(ok).To(BeTrue()) + + // Should only be one of the specified types + Expect(openaiError.StatusCode == 401 || openaiError.StatusCode == 503).To(BeTrue()) + Expect(openaiError.Type == 
openaiserverapi.ErrorCodeToType(401) || + openaiError.Type == openaiserverapi.ErrorCodeToType(503)).To(BeTrue()) + } + }) + }) + + Context("with 0% failure injection rate", func() { + BeforeEach(func() { + ctx = context.Background() + var err error + client, err = startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--failure-injection-rate", "0", + }, nil) + Expect(err).ToNot(HaveOccurred()) + }) + + It("should never return errors and behave like random mode", func() { + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + resp, err := openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + }) + + Expect(err).ToNot(HaveOccurred()) + Expect(resp.Choices).To(HaveLen(1)) + Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) + Expect(resp.Model).To(Equal(model)) + }) + }) + + Context("testing all predefined failure types", func() { + DescribeTable("should return correct error for each failure type", + func(failureType string, expectedStatusCode int, expectedErrorType string) { + ctx := context.Background() + client, err := startServerWithArgs(ctx, "failure", []string{ + "cmd", "--model", model, + "--failure-injection-rate", "100", + "--failure-types", failureType, + }, nil) + Expect(err).ToNot(HaveOccurred()) + + openaiClient := openai.NewClient( + option.WithBaseURL(baseURL), + option.WithHTTPClient(client), + ) + + _, err = openaiClient.Chat.Completions.New(ctx, openai.ChatCompletionNewParams{ + Model: model, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage(userMessage), + }, + }) + + Expect(err).To(HaveOccurred()) + + var openaiError *openai.Error + ok := errors.As(err, &openaiError) + Expect(ok).To(BeTrue()) + Expect(openaiError.StatusCode).To(Equal(expectedStatusCode)) + Expect(openaiError.Type).To(Equal(expectedErrorType)) + // Note: OpenAI Go client doesn't directly expose the error code field, + // but we can verify via status code and type + }, + Entry("rate_limit", common.FailureTypeRateLimit, 429, openaiserverapi.ErrorCodeToType(429)), + Entry("invalid_api_key", common.FailureTypeInvalidAPIKey, 401, openaiserverapi.ErrorCodeToType(401)), + Entry("context_length", common.FailureTypeContextLength, 400, openaiserverapi.ErrorCodeToType(400)), + Entry("server_error", common.FailureTypeServerError, 503, openaiserverapi.ErrorCodeToType(503)), + Entry("invalid_request", common.FailureTypeInvalidRequest, 400, openaiserverapi.ErrorCodeToType(400)), + Entry("model_not_found", common.FailureTypeModelNotFound, 404, openaiserverapi.ErrorCodeToType(404)), + ) + }) + }) +}) diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 9f56f798..323a0162 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -294,20 +294,20 @@ func (s *VllmSimulator) HandleUnloadLora(ctx *fasthttp.RequestCtx) { s.unloadLora(ctx) } -func (s *VllmSimulator) validateRequest(req openaiserverapi.CompletionRequest) (string, string, int) { +func (s *VllmSimulator) validateRequest(req openaiserverapi.CompletionRequest) (string, int) { if !s.isValidModel(req.GetModel()) { - return fmt.Sprintf("The model `%s` does not exist.", req.GetModel()), "NotFoundError", fasthttp.StatusNotFound + return fmt.Sprintf("The model `%s` does not exist.", req.GetModel()), fasthttp.StatusNotFound } if req.GetMaxCompletionTokens() != nil && 
*req.GetMaxCompletionTokens() <= 0 { - return "Max completion tokens and max tokens should be positive", "Invalid request", fasthttp.StatusBadRequest + return "Max completion tokens and max tokens should be positive", fasthttp.StatusBadRequest } if req.IsDoRemoteDecode() && req.IsStream() { - return "Prefill does not support streaming", "Invalid request", fasthttp.StatusBadRequest + return "Prefill does not support streaming", fasthttp.StatusBadRequest } - return "", "", fasthttp.StatusOK + return "", fasthttp.StatusOK } // isValidModel checks if the given model is the base model or one of "loaded" LoRAs @@ -339,6 +339,13 @@ func (s *VllmSimulator) isLora(model string) bool { // handleCompletions general completion requests handler, support both text and chat completion APIs func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatCompletion bool) { + // Check if we should inject a failure + if shouldInjectFailure(s.config) { + failure := getRandomFailure(s.config) + s.sendCompletionError(ctx, failure, true) + return + } + vllmReq, err := s.readRequest(ctx, isChatCompletion) if err != nil { s.logger.Error(err, "failed to read and parse request body") @@ -346,9 +353,9 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple return } - errMsg, errType, errCode := s.validateRequest(vllmReq) + errMsg, errCode := s.validateRequest(vllmReq) if errMsg != "" { - s.sendCompletionError(ctx, errMsg, errType, errCode) + s.sendCompletionError(ctx, openaiserverapi.NewCompletionError(errMsg, errCode, nil), false) return } @@ -375,8 +382,9 @@ func (s *VllmSimulator) handleCompletions(ctx *fasthttp.RequestCtx, isChatComple completionTokens := vllmReq.GetMaxCompletionTokens() isValid, actualCompletionTokens, totalTokens := common.ValidateContextWindow(promptTokens, completionTokens, s.config.MaxModelLen) if !isValid { - s.sendCompletionError(ctx, fmt.Sprintf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). Please reduce the length of the messages or completion", - s.config.MaxModelLen, totalTokens, promptTokens, actualCompletionTokens), "BadRequestError", fasthttp.StatusBadRequest) + message := fmt.Sprintf("This model's maximum context length is %d tokens. However, you requested %d tokens (%d in the messages, %d in the completion). 
Please reduce the length of the messages or completion", + s.config.MaxModelLen, totalTokens, promptTokens, actualCompletionTokens) + s.sendCompletionError(ctx, openaiserverapi.NewCompletionError(message, fasthttp.StatusBadRequest, nil), false) return } @@ -528,22 +536,25 @@ func (s *VllmSimulator) responseSentCallback(model string) { } // sendCompletionError sends an error response for the current completion request -func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, msg string, errType string, code int) { - compErr := openaiserverapi.CompletionError{ - Object: "error", - Message: msg, - Type: errType, - Code: code, - Param: nil, +// isInjected indicates if this is an injected failure for logging purposes +func (s *VllmSimulator) sendCompletionError(ctx *fasthttp.RequestCtx, + compErr openaiserverapi.CompletionError, isInjected bool) { + if isInjected { + s.logger.Info("Injecting failure", "type", compErr.Type, "message", compErr.Message) + } else { + s.logger.Error(nil, compErr.Message) + } + + errorResp := openaiserverapi.ErrorResponse{ + Error: compErr, } - s.logger.Error(nil, compErr.Message) - data, err := json.Marshal(compErr) + data, err := json.Marshal(errorResp) if err != nil { ctx.Error(err.Error(), fasthttp.StatusInternalServerError) } else { ctx.SetContentType("application/json") - ctx.SetStatusCode(code) + ctx.SetStatusCode(compErr.Code) ctx.SetBody(data) } } diff --git a/pkg/llm-d-inference-sim/simulator_test.go b/pkg/llm-d-inference-sim/simulator_test.go index 2641e5b9..9e4c882b 100644 --- a/pkg/llm-d-inference-sim/simulator_test.go +++ b/pkg/llm-d-inference-sim/simulator_test.go @@ -829,5 +829,4 @@ var _ = Describe("Simulator", func() { Entry(nil, 10000, 0, 1000, 0, false), ) }) - }) diff --git a/pkg/openai-server-api/response.go b/pkg/openai-server-api/response.go index a8f4a652..d32784e3 100644 --- a/pkg/openai-server-api/response.go +++ b/pkg/openai-server-api/response.go @@ -21,6 +21,8 @@ import ( "encoding/json" "errors" "strings" + + "github.com/valyala/fasthttp" ) // CompletionResponse interface representing both completion response types (text and chat) @@ -208,14 +210,53 @@ type ChatRespChunkChoice struct { // CompletionError defines the simulator's response in case of an error type CompletionError struct { - // Object is a type of this Object, "error" - Object string `json:"object"` // Message is an error Message Message string `json:"message"` // Type is a type of the error Type string `json:"type"` - // Params is the error's parameters + // Param is the error's parameter Param *string `json:"param"` - // Code is http status Code - Code int `json:"code"` + // Code is the error code integer (same as HTTP status code) + Code int `json:"code,omitempty"` +} + +// NewCompletionError creates a new CompletionError +func NewCompletionError(message string, code int, param *string) CompletionError { + return CompletionError{ + Message: message, + Code: code, + Type: ErrorCodeToType(code), + Param: param, + } +} + +// ErrorResponse wraps the error in the expected OpenAI format +type ErrorResponse struct { + Error CompletionError `json:"error"` +} + +// ErrorCodeToType maps error code to error type according to https://www.npmjs.com/package/openai +func ErrorCodeToType(code int) string { + errorType := "" + switch code { + case fasthttp.StatusBadRequest: + errorType = "BadRequestError" + case fasthttp.StatusUnauthorized: + errorType = "AuthenticationError" + case fasthttp.StatusForbidden: + errorType = "PermissionDeniedError" + case 
fasthttp.StatusNotFound: + errorType = "NotFoundError" + case fasthttp.StatusUnprocessableEntity: + errorType = "UnprocessableEntityError" + case fasthttp.StatusTooManyRequests: + errorType = "RateLimitError" + default: + if code >= fasthttp.StatusInternalServerError { + errorType = "InternalServerError" + } else { + errorType = "APIConnectionError" + } + } + return errorType }
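
Example: the error body produced after this change. `sendCompletionError` now marshals an `ErrorResponse` wrapper rather than a bare `CompletionError`, so clients receive an OpenAI-style `{"error": {...}}` object. A minimal sketch using only the `NewCompletionError` and `ErrorResponse` helpers introduced above; the model name baked into the message is illustrative.

```go
package main

import (
	"encoding/json"
	"fmt"

	openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
)

func main() {
	// Build the same kind of error the simulator injects for the rate_limit failure type.
	compErr := openaiserverapi.NewCompletionError(
		"Rate limit reached for my-model in organization org-xxx on requests per min (RPM): Limit 3, Used 3, Requested 1.",
		429, nil)

	// sendCompletionError marshals the ErrorResponse wrapper, not the bare error struct.
	body, err := json.Marshal(openaiserverapi.ErrorResponse{Error: compErr})
	if err != nil {
		panic(err)
	}
	fmt.Println(string(body))
	// Output (shape):
	// {"error":{"message":"Rate limit reached for my-model ...","type":"RateLimitError","param":null,"code":429}}
}
```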
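
Usage sketch for the new flags from a client's point of view. It assumes the simulator was started with `--failure-injection-rate 100 --failure-types rate_limit` and is reachable at `http://localhost:8000/v1` (the address is an assumption, not something this diff defines); the error handling mirrors what failures_test.go does with the OpenAI Go client.

```go
package main

import (
	"context"
	"errors"
	"fmt"

	"github.com/openai/openai-go"
	"github.com/openai/openai-go/option"
)

func main() {
	// Base URL is illustrative; point it at wherever the simulator is listening.
	client := openai.NewClient(option.WithBaseURL("http://localhost:8000/v1"))

	_, err := client.Chat.Completions.New(context.Background(), openai.ChatCompletionNewParams{
		Model: "my-model",
		Messages: []openai.ChatCompletionMessageParamUnion{
			openai.UserMessage("hello"),
		},
	})

	// With a 100% injection rate limited to rate_limit, every request should fail
	// with an *openai.Error carrying status 429 and type RateLimitError.
	var apiErr *openai.Error
	if errors.As(err, &apiErr) {
		fmt.Println(apiErr.StatusCode, apiErr.Type, apiErr.Message)
	}
}
```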