Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 34 additions & 34 deletions pkg/openai-server-api/tools_utils.go → pkg/common/tools_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ See the License for the specific language governing permissions and
limitations under the License.
*/

package openaiserverapi
package common

import (
"encoding/json"
"fmt"

"github.com/llm-d/llm-d-inference-sim/pkg/common"
openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
"github.com/santhosh-tekuri/jsonschema/v5"
)

Expand All @@ -30,7 +30,7 @@ const (
ToolChoiceRequired = "required"
)

func CountTokensForToolCalls(toolCalls []ToolCall) int {
func CountTokensForToolCalls(toolCalls []openaiserverapi.ToolCall) int {
numberOfTokens := 0
for _, tc := range toolCalls {
// 3 - name, id, and type
Expand All @@ -55,7 +55,7 @@ var fakeStringArguments = []string{
// CreateToolCalls creates and returns response payload based on this request
// (tool calls or nothing in case we randomly choose not to generate calls),
// and the number of generated completion tokens and the finish reason
func CreateToolCalls(tools []Tool, toolChoice string, config *common.Configuration) ([]ToolCall, int, error) {
func CreateToolCalls(tools []openaiserverapi.Tool, toolChoice string, config *Configuration) ([]openaiserverapi.ToolCall, int, error) {
// This function is called if tool choice is either 'required' or 'auto'.
// In case of 'required' at least one tool call has to be created, and we randomly choose
// the number of calls starting from one. Otherwise, we start from 0, and in case we randomly
Expand All @@ -64,16 +64,16 @@ func CreateToolCalls(tools []Tool, toolChoice string, config *common.Configurati
if toolChoice == ToolChoiceRequired {
min = 1
}
numberOfCalls := common.RandomInt(min, len(tools))
numberOfCalls := RandomInt(min, len(tools))
if numberOfCalls == 0 {
return nil, 0, nil
}

calls := make([]ToolCall, 0)
calls := make([]openaiserverapi.ToolCall, 0)
for i := range numberOfCalls {
// Randomly choose which tools to call. We may call the same tool more than once.
index := common.RandomInt(0, len(tools)-1)
args, err := GenerateToolArguments(tools[index], config)
index := RandomInt(0, len(tools)-1)
args, err := generateToolArguments(tools[index], config)
if err != nil {
return nil, 0, err
}
Expand All @@ -82,13 +82,13 @@ func CreateToolCalls(tools []Tool, toolChoice string, config *common.Configurati
return nil, 0, err
}

call := ToolCall{
Function: FunctionCall{
call := openaiserverapi.ToolCall{
Function: openaiserverapi.FunctionCall{
Arguments: string(argsJson),
TokenizedArguments: common.Tokenize(string(argsJson)),
TokenizedArguments: Tokenize(string(argsJson)),
Name: &tools[index].Function.Name,
},
ID: "chatcmpl-tool-" + common.RandomNumericString(10),
ID: "chatcmpl-tool-" + RandomNumericString(10),
Type: "function",
Index: i,
}
Expand All @@ -98,7 +98,7 @@ func CreateToolCalls(tools []Tool, toolChoice string, config *common.Configurati
return calls, CountTokensForToolCalls(calls), nil
}

func GetRequiredAsMap(property map[string]any) map[string]struct{} {
func getRequiredAsMap(property map[string]any) map[string]struct{} {
required := make(map[string]struct{})
requiredParams, ok := property["required"]
if ok {
Expand All @@ -111,18 +111,18 @@ func GetRequiredAsMap(property map[string]any) map[string]struct{} {
return required
}

func GenerateToolArguments(tool Tool, config *common.Configuration) (map[string]any, error) {
func generateToolArguments(tool openaiserverapi.Tool, config *Configuration) (map[string]any, error) {
arguments := make(map[string]any)
properties, _ := tool.Function.Parameters["properties"].(map[string]any)

required := GetRequiredAsMap(tool.Function.Parameters)
required := getRequiredAsMap(tool.Function.Parameters)

for param, property := range properties {
_, paramIsRequired := required[param]
if !paramIsRequired && !common.RandomBool(config.ToolCallNotRequiredParamProbability) {
if !paramIsRequired && !RandomBool(config.ToolCallNotRequiredParamProbability) {
continue
}
arg, err := CreateArgument(property, config)
arg, err := createArgument(property, config)
if err != nil {
return nil, err
}
Expand All @@ -132,7 +132,7 @@ func GenerateToolArguments(tool Tool, config *common.Configuration) (map[string]
return arguments, nil
}

func CreateArgument(property any, config *common.Configuration) (any, error) {
func createArgument(property any, config *Configuration) (any, error) {
propertyMap, _ := property.(map[string]any)
paramType := propertyMap["type"]

Expand All @@ -141,20 +141,20 @@ func CreateArgument(property any, config *common.Configuration) (any, error) {
if ok {
enumArray, ok := enum.([]any)
if ok && len(enumArray) > 0 {
index := common.RandomInt(0, len(enumArray)-1)
index := RandomInt(0, len(enumArray)-1)
return enumArray[index], nil
}
}

switch paramType {
case "string":
return GetStringArgument(), nil
return getStringArgument(), nil
case "integer":
return common.RandomInt(config.MinToolCallIntegerParam, config.MaxToolCallIntegerParam), nil
return RandomInt(config.MinToolCallIntegerParam, config.MaxToolCallIntegerParam), nil
case "number":
return common.RandomFloat(config.MinToolCallNumberParam, config.MaxToolCallNumberParam), nil
return RandomFloat(config.MinToolCallNumberParam, config.MaxToolCallNumberParam), nil
case "boolean":
return common.FlipCoin(), nil
return FlipCoin(), nil
case "array":
items := propertyMap["items"]
itemsMap := items.(map[string]any)
Expand All @@ -169,26 +169,26 @@ func CreateArgument(property any, config *common.Configuration) (any, error) {
if minItems > maxItems {
return nil, fmt.Errorf("minItems (%d) is greater than maxItems(%d)", minItems, maxItems)
}
numberOfElements := common.RandomInt(minItems, maxItems)
numberOfElements := RandomInt(minItems, maxItems)
array := make([]any, numberOfElements)
for i := range numberOfElements {
elem, err := CreateArgument(itemsMap, config)
elem, err := createArgument(itemsMap, config)
if err != nil {
return nil, err
}
array[i] = elem
}
return array, nil
case "object":
required := GetRequiredAsMap(propertyMap)
required := getRequiredAsMap(propertyMap)
objectProperties := propertyMap["properties"].(map[string]any)
object := make(map[string]interface{})
for fieldName, fieldProperties := range objectProperties {
_, fieldIsRequired := required[fieldName]
if !fieldIsRequired && !common.RandomBool(config.ObjectToolCallNotRequiredParamProbability) {
if !fieldIsRequired && !RandomBool(config.ObjectToolCallNotRequiredParamProbability) {
continue
}
fieldValue, err := CreateArgument(fieldProperties, config)
fieldValue, err := createArgument(fieldProperties, config)
if err != nil {
return nil, err
}
Expand All @@ -200,24 +200,24 @@ func CreateArgument(property any, config *common.Configuration) (any, error) {
}
}

func GetStringArgument() string {
index := common.RandomInt(0, len(fakeStringArguments)-1)
func getStringArgument() string {
index := RandomInt(0, len(fakeStringArguments)-1)
return fakeStringArguments[index]
}

type Validator struct {
type ToolsValidator struct {
schema *jsonschema.Schema
}

func CreateValidator() (*Validator, error) {
func CreateToolsValidator() (*ToolsValidator, error) {
sch, err := jsonschema.CompileString("schema.json", schema)
if err != nil {
return nil, err
}
return &Validator{schema: sch}, nil
return &ToolsValidator{schema: sch}, nil
}

func (v *Validator) ValidateTool(tool []byte) error {
func (v *ToolsValidator) ValidateTool(tool []byte) error {
var value interface{}
if err := json.Unmarshal(tool, &value); err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion pkg/dataset/custom_dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode st
if mode == common.ModeEcho {
return d.echo(req)
}
nTokensToGen, finishReason := howManyTokensToGen(d.extractMaxTokens(req), req.GetIgnoreEOS())
nTokensToGen, finishReason := howManyTokensToGen(req.ExtractMaxTokens(), req.GetIgnoreEOS())
tokens, err := d.GenerateTokens(req, nTokensToGen, finishReason)
return tokens, finishReason, err
}
Expand Down
34 changes: 2 additions & 32 deletions pkg/dataset/dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package dataset

import (
"context"
"errors"
"math"
"math/rand"

Expand Down Expand Up @@ -291,12 +290,7 @@ func (d *BaseDataset) Close() error {
}

func (d *BaseDataset) echo(req openaiserverapi.CompletionRequest) ([]string, string, error) {
nMaxTokens := d.extractMaxTokens(req)
prompt, err := d.extractPrompt(req)
if err != nil {
return nil, "", err
}
tokens, finishReason := EchoResponseTokens(nMaxTokens, prompt)
tokens, finishReason := EchoResponseTokens(req.ExtractMaxTokens(), req.ExtractPrompt())
return tokens, finishReason, nil
}

Expand All @@ -305,30 +299,6 @@ func (d *BaseDataset) GetTokens(req openaiserverapi.CompletionRequest, mode stri
if mode == common.ModeEcho {
return d.echo(req)
}
nTokensToGen, finishReason := howManyTokensToGen(d.extractMaxTokens(req), req.GetIgnoreEOS())
nTokensToGen, finishReason := howManyTokensToGen(req.ExtractMaxTokens(), req.GetIgnoreEOS())
return GenPresetRandomTokens(nTokensToGen), finishReason, nil
}

// extractMaxTokens returns the max-tokens value carried by the request,
// dispatching on the request's concrete type:
//   - chat completion: the result of GetMaxCompletionTokens
//   - text completion: the MaxTokens field
//
// Any other request type yields nil.
func (d *BaseDataset) extractMaxTokens(req openaiserverapi.CompletionRequest) *int64 {
	switch typed := req.(type) {
	case *openaiserverapi.ChatCompletionRequest:
		return typed.GetMaxCompletionTokens()
	case *openaiserverapi.TextCompletionRequest:
		return typed.MaxTokens
	}
	return nil
}

// extractPrompt returns the prompt text of the request, dispatching on the
// request's concrete type:
//   - chat completion: the result of GetLastUserMsg
//   - text completion: the result of GetPrompt
//
// Any other request type produces an "unknown request type" error.
func (d *BaseDataset) extractPrompt(req openaiserverapi.CompletionRequest) (string, error) {
	switch typed := req.(type) {
	case *openaiserverapi.ChatCompletionRequest:
		return typed.GetLastUserMsg(), nil
	case *openaiserverapi.TextCompletionRequest:
		return typed.GetPrompt(), nil
	}
	return "", errors.New("unknown request type")
}
4 changes: 1 addition & 3 deletions pkg/dataset/dataset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,9 @@ var _ = Describe("Dataset", Ordered, func() {
func(maxCompletionTokens int) {
n := int64(maxCompletionTokens)
req := &openaiserverapi.ChatCompletionRequest{
BaseCompletionRequest: openaiserverapi.BaseCompletionRequest{
IgnoreEOS: true,
},
MaxTokens: &n,
}
req.SetIgnoreEOS(true)
tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
Expect(err).ShouldNot(HaveOccurred())
nGenTokens := int64(len(tokens))
Expand Down
7 changes: 7 additions & 0 deletions pkg/llm-d-inference-sim/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ package llmdinferencesim
import (
"encoding/json"
"fmt"

"github.com/llm-d/llm-d-inference-sim/pkg/common"
openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
)

// isValidModel checks if the given model is the base model or one of "loaded" LoRAs
Expand Down Expand Up @@ -92,3 +95,7 @@ func (s *VllmSimulator) showConfig(dp bool) error {
s.logger.Info("Configuration:", "", string(cfgJSON))
return nil
}

// getNumberOfPromptTokens tokenizes the request's prompt with common.Tokenize
// and returns the length of the resulting token list.
func (s *VllmSimulator) getNumberOfPromptTokens(req openaiserverapi.CompletionRequest) int {
	promptTokens := common.Tokenize(req.GetPrompt())
	return len(promptTokens)
}
2 changes: 1 addition & 1 deletion pkg/llm-d-inference-sim/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ func (s *VllmSimulator) validateRequest(req openaiserverapi.CompletionRequest) (
}

// Validate context window constraints
promptTokens := req.GetNumberOfPromptTokens()
promptTokens := s.getNumberOfPromptTokens(req)
completionTokens := req.GetMaxCompletionTokens()
isValid, actualCompletionTokens, totalTokens := common.ValidateContextWindow(promptTokens, completionTokens, s.config.MaxModelLen)
if !isValid {
Expand Down
26 changes: 9 additions & 17 deletions pkg/llm-d-inference-sim/simulator.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ type VllmSimulator struct {
// loraAdaptors contains list of LoRA available adaptors
loraAdaptors sync.Map
// schema validator for tools parameters
toolsValidator *openaiserverapi.Validator
toolsValidator *common.ToolsValidator
// kv cache functionality
kvcacheHelper *kvcache.KVCacheHelper
// namespace where simulator is running
Expand Down Expand Up @@ -175,7 +175,7 @@ type VllmSimulator struct {

// New creates a new VllmSimulator instance with the given logger
func New(logger logr.Logger) (*VllmSimulator, error) {
toolsValidator, err := openaiserverapi.CreateValidator()
toolsValidator, err := common.CreateToolsValidator()
if err != nil {
return nil, fmt.Errorf("failed to create tools validator: %s", err)
}
Expand Down Expand Up @@ -521,12 +521,8 @@ func (s *VllmSimulator) responseSentCallback(model string, isChatCompletion bool
// from --served-model-name (for a base-model request) or the LoRA adapter name (for a LoRA request).
func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respTokens []string, toolCalls []openaiserverapi.ToolCall,
finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool) openaiserverapi.CompletionResponse {
baseResp := openaiserverapi.BaseCompletionResponse{
ID: chatComplIDPrefix + common.GenerateUUIDString(),
Created: time.Now().Unix(),
Model: modelName,
Usage: usageData,
}
baseResp := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+common.GenerateUUIDString(),
time.Now().Unix(), modelName, usageData)

if doRemoteDecode {
// add special fields related to the prefill pod special behavior
Expand All @@ -539,7 +535,7 @@ func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respToke
baseResp.RemotePort = 1234
}

baseChoice := openaiserverapi.BaseResponseChoice{Index: 0, FinishReason: finishReason}
baseChoice := openaiserverapi.CreateBaseResponseChoice(0, finishReason)

respText := strings.Join(respTokens, "")
if isChatCompletion {
Expand All @@ -551,17 +547,13 @@ func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respToke
} else {
message.Content = openaiserverapi.Content{Raw: respText}
}
return &openaiserverapi.ChatCompletionResponse{
BaseCompletionResponse: baseResp,
Choices: []openaiserverapi.ChatRespChoice{{Message: message, BaseResponseChoice: baseChoice}},
}
return openaiserverapi.CreateChatCompletionResponse(baseResp,
[]openaiserverapi.ChatRespChoice{openaiserverapi.CreateChatRespChoice(baseChoice, message)})
}

baseResp.Object = textCompletionObject
return &openaiserverapi.TextCompletionResponse{
BaseCompletionResponse: baseResp,
Choices: []openaiserverapi.TextRespChoice{{BaseResponseChoice: baseChoice, Text: respText}},
}
return openaiserverapi.CreateTextCompletionResponse(baseResp,
[]openaiserverapi.TextRespChoice{openaiserverapi.CreateTextRespChoice(baseChoice, respText)})
}

// sendResponse sends response for completion API, supports both completions (text and chat)
Expand Down
Loading