From b8438ecac394670eba7767bbbbc5f810cef0cb7e Mon Sep 17 00:00:00 2001 From: irar2 Date: Thu, 23 Oct 2025 13:06:48 +0300 Subject: [PATCH] Change packages' dependencies Signed-off-by: irar2 --- .../tools_utils.go | 68 ++++++++--------- pkg/dataset/custom_dataset.go | 2 +- pkg/dataset/dataset.go | 34 +-------- pkg/dataset/dataset_test.go | 4 +- pkg/llm-d-inference-sim/helpers.go | 7 ++ pkg/llm-d-inference-sim/server.go | 2 +- pkg/llm-d-inference-sim/simulator.go | 26 +++---- pkg/llm-d-inference-sim/streaming.go | 60 +++++---------- pkg/llm-d-inference-sim/worker.go | 8 +- pkg/openai-server-api/request.go | 74 +++++++++++++------ pkg/openai-server-api/response.go | 52 ++++++++++--- 11 files changed, 169 insertions(+), 168 deletions(-) rename pkg/{openai-server-api => common}/tools_utils.go (83%) diff --git a/pkg/openai-server-api/tools_utils.go b/pkg/common/tools_utils.go similarity index 83% rename from pkg/openai-server-api/tools_utils.go rename to pkg/common/tools_utils.go index 58f3a0df..8df7ebd2 100644 --- a/pkg/openai-server-api/tools_utils.go +++ b/pkg/common/tools_utils.go @@ -14,13 +14,13 @@ See the License for the specific language governing permissions and limitations under the License. */ -package openaiserverapi +package common import ( "encoding/json" "fmt" - "github.com/llm-d/llm-d-inference-sim/pkg/common" + openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" "github.com/santhosh-tekuri/jsonschema/v5" ) @@ -30,7 +30,7 @@ const ( ToolChoiceRequired = "required" ) -func CountTokensForToolCalls(toolCalls []ToolCall) int { +func CountTokensForToolCalls(toolCalls []openaiserverapi.ToolCall) int { numberOfTokens := 0 for _, tc := range toolCalls { // 3 - name, id, and type @@ -55,7 +55,7 @@ var fakeStringArguments = []string{ // CreateToolCalls creates and returns response payload based on this request // (tool calls or nothing in case we randomly choose not to generate calls), // and the number of generated completion token sand the finish reason -func CreateToolCalls(tools []Tool, toolChoice string, config *common.Configuration) ([]ToolCall, int, error) { +func CreateToolCalls(tools []openaiserverapi.Tool, toolChoice string, config *Configuration) ([]openaiserverapi.ToolCall, int, error) { // This function is called if tool choice is either 'required' or 'auto'. // In case of 'required' at least one tool call has to be created, and we randomly choose // the number of calls starting from one. Otherwise, we start from 0, and in case we randomly @@ -64,16 +64,16 @@ func CreateToolCalls(tools []Tool, toolChoice string, config *common.Configurati if toolChoice == ToolChoiceRequired { min = 1 } - numberOfCalls := common.RandomInt(min, len(tools)) + numberOfCalls := RandomInt(min, len(tools)) if numberOfCalls == 0 { return nil, 0, nil } - calls := make([]ToolCall, 0) + calls := make([]openaiserverapi.ToolCall, 0) for i := range numberOfCalls { // Randomly choose which tools to call. We may call the same tool more than once. 
- index := common.RandomInt(0, len(tools)-1) - args, err := GenerateToolArguments(tools[index], config) + index := RandomInt(0, len(tools)-1) + args, err := generateToolArguments(tools[index], config) if err != nil { return nil, 0, err } @@ -82,13 +82,13 @@ func CreateToolCalls(tools []Tool, toolChoice string, config *common.Configurati return nil, 0, err } - call := ToolCall{ - Function: FunctionCall{ + call := openaiserverapi.ToolCall{ + Function: openaiserverapi.FunctionCall{ Arguments: string(argsJson), - TokenizedArguments: common.Tokenize(string(argsJson)), + TokenizedArguments: Tokenize(string(argsJson)), Name: &tools[index].Function.Name, }, - ID: "chatcmpl-tool-" + common.RandomNumericString(10), + ID: "chatcmpl-tool-" + RandomNumericString(10), Type: "function", Index: i, } @@ -98,7 +98,7 @@ func CreateToolCalls(tools []Tool, toolChoice string, config *common.Configurati return calls, CountTokensForToolCalls(calls), nil } -func GetRequiredAsMap(property map[string]any) map[string]struct{} { +func getRequiredAsMap(property map[string]any) map[string]struct{} { required := make(map[string]struct{}) requiredParams, ok := property["required"] if ok { @@ -111,18 +111,18 @@ func GetRequiredAsMap(property map[string]any) map[string]struct{} { return required } -func GenerateToolArguments(tool Tool, config *common.Configuration) (map[string]any, error) { +func generateToolArguments(tool openaiserverapi.Tool, config *Configuration) (map[string]any, error) { arguments := make(map[string]any) properties, _ := tool.Function.Parameters["properties"].(map[string]any) - required := GetRequiredAsMap(tool.Function.Parameters) + required := getRequiredAsMap(tool.Function.Parameters) for param, property := range properties { _, paramIsRequired := required[param] - if !paramIsRequired && !common.RandomBool(config.ToolCallNotRequiredParamProbability) { + if !paramIsRequired && !RandomBool(config.ToolCallNotRequiredParamProbability) { continue } - arg, err := CreateArgument(property, config) + arg, err := createArgument(property, config) if err != nil { return nil, err } @@ -132,7 +132,7 @@ func GenerateToolArguments(tool Tool, config *common.Configuration) (map[string] return arguments, nil } -func CreateArgument(property any, config *common.Configuration) (any, error) { +func createArgument(property any, config *Configuration) (any, error) { propertyMap, _ := property.(map[string]any) paramType := propertyMap["type"] @@ -141,20 +141,20 @@ func CreateArgument(property any, config *common.Configuration) (any, error) { if ok { enumArray, ok := enum.([]any) if ok && len(enumArray) > 0 { - index := common.RandomInt(0, len(enumArray)-1) + index := RandomInt(0, len(enumArray)-1) return enumArray[index], nil } } switch paramType { case "string": - return GetStringArgument(), nil + return getStringArgument(), nil case "integer": - return common.RandomInt(config.MinToolCallIntegerParam, config.MaxToolCallIntegerParam), nil + return RandomInt(config.MinToolCallIntegerParam, config.MaxToolCallIntegerParam), nil case "number": - return common.RandomFloat(config.MinToolCallNumberParam, config.MaxToolCallNumberParam), nil + return RandomFloat(config.MinToolCallNumberParam, config.MaxToolCallNumberParam), nil case "boolean": - return common.FlipCoin(), nil + return FlipCoin(), nil case "array": items := propertyMap["items"] itemsMap := items.(map[string]any) @@ -169,10 +169,10 @@ func CreateArgument(property any, config *common.Configuration) (any, error) { if minItems > maxItems { return nil, fmt.Errorf("minItems 
(%d) is greater than maxItems(%d)", minItems, maxItems) } - numberOfElements := common.RandomInt(minItems, maxItems) + numberOfElements := RandomInt(minItems, maxItems) array := make([]any, numberOfElements) for i := range numberOfElements { - elem, err := CreateArgument(itemsMap, config) + elem, err := createArgument(itemsMap, config) if err != nil { return nil, err } @@ -180,15 +180,15 @@ func CreateArgument(property any, config *common.Configuration) (any, error) { } return array, nil case "object": - required := GetRequiredAsMap(propertyMap) + required := getRequiredAsMap(propertyMap) objectProperties := propertyMap["properties"].(map[string]any) object := make(map[string]interface{}) for fieldName, fieldProperties := range objectProperties { _, fieldIsRequired := required[fieldName] - if !fieldIsRequired && !common.RandomBool(config.ObjectToolCallNotRequiredParamProbability) { + if !fieldIsRequired && !RandomBool(config.ObjectToolCallNotRequiredParamProbability) { continue } - fieldValue, err := CreateArgument(fieldProperties, config) + fieldValue, err := createArgument(fieldProperties, config) if err != nil { return nil, err } @@ -200,24 +200,24 @@ func CreateArgument(property any, config *common.Configuration) (any, error) { } } -func GetStringArgument() string { - index := common.RandomInt(0, len(fakeStringArguments)-1) +func getStringArgument() string { + index := RandomInt(0, len(fakeStringArguments)-1) return fakeStringArguments[index] } -type Validator struct { +type ToolsValidator struct { schema *jsonschema.Schema } -func CreateValidator() (*Validator, error) { +func CreateToolsValidator() (*ToolsValidator, error) { sch, err := jsonschema.CompileString("schema.json", schema) if err != nil { return nil, err } - return &Validator{schema: sch}, nil + return &ToolsValidator{schema: sch}, nil } -func (v *Validator) ValidateTool(tool []byte) error { +func (v *ToolsValidator) ValidateTool(tool []byte) error { var value interface{} if err := json.Unmarshal(tool, &value); err != nil { return err diff --git a/pkg/dataset/custom_dataset.go b/pkg/dataset/custom_dataset.go index ecf7659a..69b2f591 100644 --- a/pkg/dataset/custom_dataset.go +++ b/pkg/dataset/custom_dataset.go @@ -439,7 +439,7 @@ func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode st if mode == common.ModeEcho { return d.echo(req) } - nTokensToGen, finishReason := howManyTokensToGen(d.extractMaxTokens(req), req.GetIgnoreEOS()) + nTokensToGen, finishReason := howManyTokensToGen(req.ExtractMaxTokens(), req.GetIgnoreEOS()) tokens, err := d.GenerateTokens(req, nTokensToGen, finishReason) return tokens, finishReason, err } diff --git a/pkg/dataset/dataset.go b/pkg/dataset/dataset.go index 41cb8354..5d5d2fe9 100644 --- a/pkg/dataset/dataset.go +++ b/pkg/dataset/dataset.go @@ -18,7 +18,6 @@ package dataset import ( "context" - "errors" "math" "math/rand" @@ -291,12 +290,7 @@ func (d *BaseDataset) Close() error { } func (d *BaseDataset) echo(req openaiserverapi.CompletionRequest) ([]string, string, error) { - nMaxTokens := d.extractMaxTokens(req) - prompt, err := d.extractPrompt(req) - if err != nil { - return nil, "", err - } - tokens, finishReason := EchoResponseTokens(nMaxTokens, prompt) + tokens, finishReason := EchoResponseTokens(req.ExtractMaxTokens(), req.ExtractPrompt()) return tokens, finishReason, nil } @@ -305,30 +299,6 @@ func (d *BaseDataset) GetTokens(req openaiserverapi.CompletionRequest, mode stri if mode == common.ModeEcho { return d.echo(req) } - nTokensToGen, finishReason := 
howManyTokensToGen(d.extractMaxTokens(req), req.GetIgnoreEOS()) + nTokensToGen, finishReason := howManyTokensToGen(req.ExtractMaxTokens(), req.GetIgnoreEOS()) return GenPresetRandomTokens(nTokensToGen), finishReason, nil } - -// extractMaxTokens extracts the max tokens from the request -// for chat completion - max_completion_tokens field is used -// for text completion - max_tokens field is used -func (d *BaseDataset) extractMaxTokens(req openaiserverapi.CompletionRequest) *int64 { - if chatReq, ok := req.(*openaiserverapi.ChatCompletionRequest); ok { - return chatReq.GetMaxCompletionTokens() - } else if textReq, ok := req.(*openaiserverapi.TextCompletionRequest); ok { - return textReq.MaxTokens - } - return nil -} - -// extractPrompt extracts the prompt from the request -// for chat completion - the last user message is used as the prompt -// for text completion - the prompt field is used -func (d *BaseDataset) extractPrompt(req openaiserverapi.CompletionRequest) (string, error) { - if chatReq, ok := req.(*openaiserverapi.ChatCompletionRequest); ok { - return chatReq.GetLastUserMsg(), nil - } else if textReq, ok := req.(*openaiserverapi.TextCompletionRequest); ok { - return textReq.GetPrompt(), nil - } - return "", errors.New("unknown request type") -} diff --git a/pkg/dataset/dataset_test.go b/pkg/dataset/dataset_test.go index 83a2953b..1321e9e9 100644 --- a/pkg/dataset/dataset_test.go +++ b/pkg/dataset/dataset_test.go @@ -92,11 +92,9 @@ var _ = Describe("Dataset", Ordered, func() { func(maxCompletionTokens int) { n := int64(maxCompletionTokens) req := &openaiserverapi.ChatCompletionRequest{ - BaseCompletionRequest: openaiserverapi.BaseCompletionRequest{ - IgnoreEOS: true, - }, MaxTokens: &n, } + req.SetIgnoreEOS(true) tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom) Expect(err).ShouldNot(HaveOccurred()) nGenTokens := int64(len(tokens)) diff --git a/pkg/llm-d-inference-sim/helpers.go b/pkg/llm-d-inference-sim/helpers.go index 60089da7..451767cb 100644 --- a/pkg/llm-d-inference-sim/helpers.go +++ b/pkg/llm-d-inference-sim/helpers.go @@ -20,6 +20,9 @@ package llmdinferencesim import ( "encoding/json" "fmt" + + "github.com/llm-d/llm-d-inference-sim/pkg/common" + openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api" ) // isValidModel checks if the given model is the base model or one of "loaded" LoRAs @@ -92,3 +95,7 @@ func (s *VllmSimulator) showConfig(dp bool) error { s.logger.Info("Configuration:", "", string(cfgJSON)) return nil } + +func (s *VllmSimulator) getNumberOfPromptTokens(req openaiserverapi.CompletionRequest) int { + return len(common.Tokenize(req.GetPrompt())) +} diff --git a/pkg/llm-d-inference-sim/server.go b/pkg/llm-d-inference-sim/server.go index 6fba147d..6fb958c1 100644 --- a/pkg/llm-d-inference-sim/server.go +++ b/pkg/llm-d-inference-sim/server.go @@ -233,7 +233,7 @@ func (s *VllmSimulator) validateRequest(req openaiserverapi.CompletionRequest) ( } // Validate context window constraints - promptTokens := req.GetNumberOfPromptTokens() + promptTokens := s.getNumberOfPromptTokens(req) completionTokens := req.GetMaxCompletionTokens() isValid, actualCompletionTokens, totalTokens := common.ValidateContextWindow(promptTokens, completionTokens, s.config.MaxModelLen) if !isValid { diff --git a/pkg/llm-d-inference-sim/simulator.go b/pkg/llm-d-inference-sim/simulator.go index 76405bb8..09dcda39 100644 --- a/pkg/llm-d-inference-sim/simulator.go +++ b/pkg/llm-d-inference-sim/simulator.go @@ -143,7 +143,7 @@ type VllmSimulator struct { // 
loraAdaptors contains list of LoRA available adaptors loraAdaptors sync.Map // schema validator for tools parameters - toolsValidator *openaiserverapi.Validator + toolsValidator *common.ToolsValidator // kv cache functionality kvcacheHelper *kvcache.KVCacheHelper // namespace where simulator is running @@ -175,7 +175,7 @@ type VllmSimulator struct { // New creates a new VllmSimulator instance with the given logger func New(logger logr.Logger) (*VllmSimulator, error) { - toolsValidator, err := openaiserverapi.CreateValidator() + toolsValidator, err := common.CreateToolsValidator() if err != nil { return nil, fmt.Errorf("failed to create tools validator: %s", err) } @@ -521,12 +521,8 @@ func (s *VllmSimulator) responseSentCallback(model string, isChatCompletion bool // from --served-model-name (for a base-model request) or the LoRA adapter name (for a LoRA request). func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respTokens []string, toolCalls []openaiserverapi.ToolCall, finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool) openaiserverapi.CompletionResponse { - baseResp := openaiserverapi.BaseCompletionResponse{ - ID: chatComplIDPrefix + common.GenerateUUIDString(), - Created: time.Now().Unix(), - Model: modelName, - Usage: usageData, - } + baseResp := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+common.GenerateUUIDString(), + time.Now().Unix(), modelName, usageData) if doRemoteDecode { // add special fields related to the prefill pod special behavior @@ -539,7 +535,7 @@ func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respToke baseResp.RemotePort = 1234 } - baseChoice := openaiserverapi.BaseResponseChoice{Index: 0, FinishReason: finishReason} + baseChoice := openaiserverapi.CreateBaseResponseChoice(0, finishReason) respText := strings.Join(respTokens, "") if isChatCompletion { @@ -551,17 +547,13 @@ func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respToke } else { message.Content = openaiserverapi.Content{Raw: respText} } - return &openaiserverapi.ChatCompletionResponse{ - BaseCompletionResponse: baseResp, - Choices: []openaiserverapi.ChatRespChoice{{Message: message, BaseResponseChoice: baseChoice}}, - } + return openaiserverapi.CreateChatCompletionResponse(baseResp, + []openaiserverapi.ChatRespChoice{openaiserverapi.CreateChatRespChoice(baseChoice, message)}) } baseResp.Object = textCompletionObject - return &openaiserverapi.TextCompletionResponse{ - BaseCompletionResponse: baseResp, - Choices: []openaiserverapi.TextRespChoice{{BaseResponseChoice: baseChoice, Text: respText}}, - } + return openaiserverapi.CreateTextCompletionResponse(baseResp, + []openaiserverapi.TextRespChoice{openaiserverapi.CreateTextRespChoice(baseChoice, respText)}) } // sendResponse sends response for completion API, supports both completions (text and chat) diff --git a/pkg/llm-d-inference-sim/streaming.go b/pkg/llm-d-inference-sim/streaming.go index fd2b6720..fac93f40 100644 --- a/pkg/llm-d-inference-sim/streaming.go +++ b/pkg/llm-d-inference-sim/streaming.go @@ -164,64 +164,40 @@ func (s *VllmSimulator) sendTokenChunks(context *streamingContext, w *bufio.Writ // createUsageChunk creates and returns a CompletionRespChunk with usage data, a single chunk of streamed completion API response, // supports both modes (text and chat) func (s *VllmSimulator) createUsageChunk(context *streamingContext, usageData *openaiserverapi.Usage) openaiserverapi.CompletionRespChunk { - baseChunk := 
openaiserverapi.BaseCompletionResponse{ - ID: chatComplIDPrefix + common.GenerateUUIDString(), - Created: context.creationTime, - Model: context.model, - Usage: usageData, - } + baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+common.GenerateUUIDString(), + context.creationTime, context.model, usageData) + if context.isChatCompletion { baseChunk.Object = chatCompletionChunkObject - return &openaiserverapi.ChatCompletionResponse{ - BaseCompletionResponse: baseChunk, - Choices: []openaiserverapi.ChatRespChoice{}, - } + return openaiserverapi.CreateChatCompletionResponse(baseChunk, []openaiserverapi.ChatRespChoice{}) } baseChunk.Object = textCompletionObject - return &openaiserverapi.TextCompletionResponse{ - BaseCompletionResponse: baseChunk, - Choices: []openaiserverapi.TextRespChoice{}, - } + return openaiserverapi.CreateTextCompletionResponse(baseChunk, []openaiserverapi.TextRespChoice{}) } // createTextCompletionChunk creates and returns a CompletionRespChunk, a single chunk of streamed completion API response, // for text completion func (s *VllmSimulator) createTextCompletionChunk(context *streamingContext, token string, finishReason *string) openaiserverapi.CompletionRespChunk { - return &openaiserverapi.TextCompletionResponse{ - BaseCompletionResponse: openaiserverapi.BaseCompletionResponse{ - ID: chatComplIDPrefix + common.GenerateUUIDString(), - Created: context.creationTime, - Model: context.model, - Object: textCompletionObject, - }, - Choices: []openaiserverapi.TextRespChoice{ - { - BaseResponseChoice: openaiserverapi.BaseResponseChoice{Index: 0, FinishReason: finishReason}, - Text: token, - }, - }, - } + baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+common.GenerateUUIDString(), + context.creationTime, context.model, nil) + baseChunk.Object = textCompletionObject + return openaiserverapi.CreateTextCompletionResponse(baseChunk, + []openaiserverapi.TextRespChoice{ + openaiserverapi.CreateTextRespChoice(openaiserverapi.CreateBaseResponseChoice(0, finishReason), token)}) } // createChatCompletionChunk creates and returns a CompletionRespChunk, a single chunk of streamed completion // API response, for chat completion. It sets either role, or token, or tool call info in the message. 
func (s *VllmSimulator) createChatCompletionChunk(context *streamingContext, token string, tool *openaiserverapi.ToolCall, role string, finishReason *string) openaiserverapi.CompletionRespChunk { - chunk := openaiserverapi.ChatCompletionRespChunk{ - BaseCompletionResponse: openaiserverapi.BaseCompletionResponse{ - ID: chatComplIDPrefix + common.GenerateUUIDString(), - Created: context.creationTime, - Model: context.model, - Object: chatCompletionChunkObject, - }, - Choices: []openaiserverapi.ChatRespChunkChoice{ - { - Delta: openaiserverapi.Message{}, - BaseResponseChoice: openaiserverapi.BaseResponseChoice{Index: 0, FinishReason: finishReason}, - }, - }, - } + baseChunk := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+common.GenerateUUIDString(), + context.creationTime, context.model, nil) + baseChunk.Object = chatCompletionChunkObject + chunk := openaiserverapi.CreateChatCompletionRespChunk(baseChunk, + []openaiserverapi.ChatRespChunkChoice{ + openaiserverapi.CreateChatRespChunkChoice( + openaiserverapi.CreateBaseResponseChoice(0, finishReason), openaiserverapi.Message{})}) if len(role) > 0 { chunk.Choices[0].Delta.Role = role diff --git a/pkg/llm-d-inference-sim/worker.go b/pkg/llm-d-inference-sim/worker.go index c79bcd02..3e05faca 100644 --- a/pkg/llm-d-inference-sim/worker.go +++ b/pkg/llm-d-inference-sim/worker.go @@ -88,10 +88,10 @@ func (s *VllmSimulator) processRequest(reqCtx *openaiserverapi.CompletionReqCtx) var toolCalls []openaiserverapi.ToolCall var completionTokens int if reqCtx.IsChatCompletion && - req.GetToolChoice() != openaiserverapi.ToolChoiceNone && + req.GetToolChoice() != common.ToolChoiceNone && req.GetTools() != nil { toolCalls, completionTokens, err = - openaiserverapi.CreateToolCalls(req.GetTools(), req.GetToolChoice(), s.config) + common.CreateToolCalls(req.GetTools(), req.GetToolChoice(), s.config) finishReason = dataset.ToolsFinishReason } if toolCalls == nil && err == nil { @@ -111,9 +111,9 @@ func (s *VllmSimulator) processRequest(reqCtx *openaiserverapi.CompletionReqCtx) reqCtx.HTTPReqCtx.Error(prefix+err.Error(), fasthttp.StatusBadRequest) } else { usageData := openaiserverapi.Usage{ - PromptTokens: req.GetNumberOfPromptTokens(), + PromptTokens: s.getNumberOfPromptTokens(req), CompletionTokens: completionTokens, - TotalTokens: req.GetNumberOfPromptTokens() + completionTokens, + TotalTokens: s.getNumberOfPromptTokens(req) + completionTokens, } if req.IsStream() { var usageDataToSend *openaiserverapi.Usage diff --git a/pkg/openai-server-api/request.go b/pkg/openai-server-api/request.go index 34db0ee6..8b17d679 100644 --- a/pkg/openai-server-api/request.go +++ b/pkg/openai-server-api/request.go @@ -20,7 +20,6 @@ package openaiserverapi import ( "sync" - "github.com/llm-d/llm-d-inference-sim/pkg/common" "github.com/valyala/fasthttp" ) @@ -39,8 +38,6 @@ type CompletionRequest interface { GetModel() string // IncludeUsage returns true if usage statistics should be include in the response IncludeUsage() bool - // GetNumberOfPromptTokens returns the number of tokens in the prompt - GetNumberOfPromptTokens() int // GetNumberOfCachedPromptTokens returns the number of tokens in the prompt that are // in the local KV Cache GetNumberOfCachedPromptTokens() int @@ -67,10 +64,18 @@ type CompletionRequest interface { IsDoRemotePrefill() bool // GetFullPrompt returns the full prompt including system and user prompts GetFullPrompt() string + // ExtractPrompt extracts the prompt from the request + // for chat completion - the last user message is used as the 
prompt + // for text completion - the prompt field is used + ExtractPrompt() string + // ExtractMaxTokens extracts the max tokens from the request + // for chat completion - max_completion_tokens field is used + // for text completion - max_tokens field is used + ExtractMaxTokens() *int64 } -// BaseCompletionRequest contains base completion request related information -type BaseCompletionRequest struct { +// baseCompletionRequest contains base completion request related information +type baseCompletionRequest struct { // RequestID is the unique id of this request RequestID string // Stream is a boolean value, defines whether response should be sent as a Stream @@ -103,44 +108,49 @@ type StreamOptions struct { IncludeUsage bool `json:"include_usage"` } -func (b *BaseCompletionRequest) GetRequestID() string { +func (b *baseCompletionRequest) GetRequestID() string { return b.RequestID } -func (b *BaseCompletionRequest) IsStream() bool { +func (b *baseCompletionRequest) IsStream() bool { return b.Stream } -func (b *BaseCompletionRequest) GetModel() string { +func (b *baseCompletionRequest) GetModel() string { return b.Model } -func (b *BaseCompletionRequest) IncludeUsage() bool { +func (b *baseCompletionRequest) IncludeUsage() bool { return !b.Stream || b.StreamOptions.IncludeUsage } -func (b *BaseCompletionRequest) IsDoRemoteDecode() bool { +func (b *baseCompletionRequest) IsDoRemoteDecode() bool { return b.DoRemoteDecode } -func (b *BaseCompletionRequest) IsDoRemotePrefill() bool { +func (b *baseCompletionRequest) IsDoRemotePrefill() bool { return b.DoRemotePrefill } // GetNumberOfCachedPromptTokens returns the number of tokens in the prompt that are // in the local KV Cache -func (b *BaseCompletionRequest) GetNumberOfCachedPromptTokens() int { +func (b *baseCompletionRequest) GetNumberOfCachedPromptTokens() int { return b.cachedPromptTokens } // GetIgnoreEOS returns the value of IgnoreEOS -func (b *BaseCompletionRequest) GetIgnoreEOS() bool { +func (b *baseCompletionRequest) GetIgnoreEOS() bool { return b.IgnoreEOS } +// SetIgnoreEOS sets the value of IgnoreEOS +func (b *baseCompletionRequest) SetIgnoreEOS(ignorEOS bool) { + b.IgnoreEOS = ignorEOS +} + // SetNumberOfCachedPromptTokens sets the number of tokens in the prompt that are // in the local KV Cache -func (b *BaseCompletionRequest) SetNumberOfCachedPromptTokens(cachedPromptTokens int) { +func (b *baseCompletionRequest) SetNumberOfCachedPromptTokens(cachedPromptTokens int) { b.cachedPromptTokens = cachedPromptTokens } @@ -155,7 +165,7 @@ type CompletionReqCtx struct { // ChatCompletionRequest defines structure of /chat/completion request type ChatCompletionRequest struct { - BaseCompletionRequest + baseCompletionRequest // Messages list of request's Messages Messages []Message `json:"messages"` @@ -207,10 +217,6 @@ func (c *ChatCompletionRequest) GetPrompt() string { return messages } -func (c *ChatCompletionRequest) GetNumberOfPromptTokens() int { - return len(common.Tokenize(c.GetPrompt())) -} - func (c *ChatCompletionRequest) GetTools() []Tool { return c.Tools } @@ -253,10 +259,22 @@ func (req *ChatCompletionRequest) GetFullPrompt() string { return prompt } +// ExtractPrompt extracts the prompt from the request +// for chat completion - the last user message is used as the prompt +func (req *ChatCompletionRequest) ExtractPrompt() string { + return req.GetLastUserMsg() +} + +// ExtractMaxTokens extracts the max tokens from the request +// for chat completion - max_completion_tokens field is used +func (req *ChatCompletionRequest) 
ExtractMaxTokens() *int64 { + return req.GetMaxCompletionTokens() +} + // v1/completion // TextCompletionRequest defines structure of /completion request type TextCompletionRequest struct { - BaseCompletionRequest + baseCompletionRequest // Prompt defines request's content Prompt string `json:"prompt"` @@ -272,10 +290,6 @@ func (t *TextCompletionRequest) GetPrompt() string { return t.Prompt } -func (t *TextCompletionRequest) GetNumberOfPromptTokens() int { - return len(common.Tokenize(t.GetPrompt())) -} - func (c *TextCompletionRequest) GetTools() []Tool { return nil } @@ -291,3 +305,15 @@ func (c *TextCompletionRequest) GetMaxCompletionTokens() *int64 { func (t *TextCompletionRequest) GetFullPrompt() string { return "### user:\n" + t.Prompt + "\n" } + +// ExtractPrompt extracts the prompt from the request +// for text completion - the prompt field is used +func (req *TextCompletionRequest) ExtractPrompt() string { + return req.GetPrompt() +} + +// ExtractMaxTokens extracts the max tokens from the request +// for text completion - max_tokens field is used +func (req *TextCompletionRequest) ExtractMaxTokens() *int64 { + return req.MaxTokens +} diff --git a/pkg/openai-server-api/response.go b/pkg/openai-server-api/response.go index d32784e3..0398a858 100644 --- a/pkg/openai-server-api/response.go +++ b/pkg/openai-server-api/response.go @@ -28,8 +28,8 @@ import ( // CompletionResponse interface representing both completion response types (text and chat) type CompletionResponse interface{} -// BaseCompletionResponse contains base completion response related information -type BaseCompletionResponse struct { +// baseCompletionResponse contains base completion response related information +type baseCompletionResponse struct { // ID defines the response ID ID string `json:"id"` // Created defines the response creation timestamp @@ -66,13 +66,13 @@ type Usage struct { // ChatCompletionResponse defines structure of /chat/completion response type ChatCompletionResponse struct { - BaseCompletionResponse + baseCompletionResponse // Choices list of Choices of the response, according of OpenAI API Choices []ChatRespChoice `json:"choices"` } -// BaseResponseChoice contains base completion response's choice related information -type BaseResponseChoice struct { +// baseResponseChoice contains base completion response's choice related information +type baseResponseChoice struct { // Index defines completion response choise Index Index int `json:"index"` // FinishReason defines finish reason for response or for chunks, for not last chinks is defined as null @@ -172,21 +172,21 @@ type ToolCall struct { // ChatRespChoice represents a single chat completion response choise type ChatRespChoice struct { - BaseResponseChoice + baseResponseChoice // Message contains choice's Message Message Message `json:"message"` } // TextCompletionResponse defines structure of /completion response type TextCompletionResponse struct { - BaseCompletionResponse + baseCompletionResponse // Choices list of Choices of the response, according of OpenAI API Choices []TextRespChoice `json:"choices"` } // TextRespChoice represents a single text completion response choise type TextRespChoice struct { - BaseResponseChoice + baseResponseChoice // Text defines request's content Text string `json:"text"` } @@ -196,14 +196,14 @@ type CompletionRespChunk interface{} // ChatCompletionRespChunk is a single chat completion response chunk type ChatCompletionRespChunk struct { - BaseCompletionResponse + baseCompletionResponse // Choices list of 
Choices of the response, according of OpenAI API Choices []ChatRespChunkChoice `json:"choices"` } // ChatRespChunkChoice represents a single chat completion response choise in case of streaming type ChatRespChunkChoice struct { - BaseResponseChoice + baseResponseChoice // Delta is a content of the chunk Delta Message `json:"delta"` } @@ -260,3 +260,35 @@ func ErrorCodeToType(code int) string { } return errorType } + +func CreateBaseResponseChoice(index int, finishReason *string) baseResponseChoice { + return baseResponseChoice{Index: index, FinishReason: finishReason} +} + +func CreateChatRespChoice(base baseResponseChoice, message Message) ChatRespChoice { + return ChatRespChoice{baseResponseChoice: base, Message: message} +} + +func CreateChatRespChunkChoice(base baseResponseChoice, message Message) ChatRespChunkChoice { + return ChatRespChunkChoice{baseResponseChoice: base, Delta: message} +} + +func CreateTextRespChoice(base baseResponseChoice, text string) TextRespChoice { + return TextRespChoice{baseResponseChoice: base, Text: text} +} + +func CreateBaseCompletionResponse(id string, created int64, model string, usage *Usage) baseCompletionResponse { + return baseCompletionResponse{ID: id, Created: created, Model: model, Usage: usage} +} + +func CreateChatCompletionResponse(base baseCompletionResponse, choices []ChatRespChoice) *ChatCompletionResponse { + return &ChatCompletionResponse{baseCompletionResponse: base, Choices: choices} +} + +func CreateTextCompletionResponse(base baseCompletionResponse, choices []TextRespChoice) *TextCompletionResponse { + return &TextCompletionResponse{baseCompletionResponse: base, Choices: choices} +} + +func CreateChatCompletionRespChunk(base baseCompletionResponse, choices []ChatRespChunkChoice) *ChatCompletionRespChunk { + return &ChatCompletionRespChunk{baseCompletionResponse: base, Choices: choices} +}
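
A minimal usage sketch (illustrative only, not part of the patch) of the response constructors introduced above, called from outside pkg/openai-server-api now that the base response types are unexported; the model name, object string, and finish reason below are placeholder values, not constants from the simulator:

package example

import (
	"time"

	"github.com/llm-d/llm-d-inference-sim/pkg/common"
	openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
)

// buildChatResponse assembles a non-streaming chat completion response through the
// new constructors. The returned base value has an unexported type, so callers can
// only obtain it via CreateBaseCompletionResponse; its exported fields remain settable.
func buildChatResponse(respText string, usage *openaiserverapi.Usage) openaiserverapi.CompletionResponse {
	baseResp := openaiserverapi.CreateBaseCompletionResponse(
		"chatcmpl-"+common.GenerateUUIDString(), time.Now().Unix(), "example-model", usage)
	baseResp.Object = "chat.completion" // placeholder; the simulator sets its own object constant

	finishReason := "stop" // placeholder finish reason
	choice := openaiserverapi.CreateChatRespChoice(
		openaiserverapi.CreateBaseResponseChoice(0, &finishReason),
		openaiserverapi.Message{Role: "assistant", Content: openaiserverapi.Content{Raw: respText}})

	return openaiserverapi.CreateChatCompletionResponse(baseResp,
		[]openaiserverapi.ChatRespChoice{choice})
}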