Skip to content

Commit 56aed1f

Browse files
authored
Change packages' dependencies (#229)
Signed-off-by: irar2 <[email protected]>
1 parent 64d8d7f commit 56aed1f

File tree

11 files changed

+169
-168
lines changed

11 files changed

+169
-168
lines changed

pkg/openai-server-api/tools_utils.go renamed to pkg/common/tools_utils.go

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ See the License for the specific language governing permissions and
1414
limitations under the License.
1515
*/
1616

17-
package openaiserverapi
17+
package common
1818

1919
import (
2020
"encoding/json"
2121
"fmt"
2222

23-
"github.com/llm-d/llm-d-inference-sim/pkg/common"
23+
openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
2424
"github.com/santhosh-tekuri/jsonschema/v5"
2525
)
2626

@@ -30,7 +30,7 @@ const (
3030
ToolChoiceRequired = "required"
3131
)
3232

33-
func CountTokensForToolCalls(toolCalls []ToolCall) int {
33+
func CountTokensForToolCalls(toolCalls []openaiserverapi.ToolCall) int {
3434
numberOfTokens := 0
3535
for _, tc := range toolCalls {
3636
// 3 - name, id, and type
@@ -55,7 +55,7 @@ var fakeStringArguments = []string{
5555
// CreateToolCalls creates and returns response payload based on this request
5656
// (tool calls or nothing in case we randomly choose not to generate calls),
5757
// and the number of generated completion token sand the finish reason
58-
func CreateToolCalls(tools []Tool, toolChoice string, config *common.Configuration) ([]ToolCall, int, error) {
58+
func CreateToolCalls(tools []openaiserverapi.Tool, toolChoice string, config *Configuration) ([]openaiserverapi.ToolCall, int, error) {
5959
// This function is called if tool choice is either 'required' or 'auto'.
6060
// In case of 'required' at least one tool call has to be created, and we randomly choose
6161
// the number of calls starting from one. Otherwise, we start from 0, and in case we randomly
@@ -64,16 +64,16 @@ func CreateToolCalls(tools []Tool, toolChoice string, config *common.Configurati
6464
if toolChoice == ToolChoiceRequired {
6565
min = 1
6666
}
67-
numberOfCalls := common.RandomInt(min, len(tools))
67+
numberOfCalls := RandomInt(min, len(tools))
6868
if numberOfCalls == 0 {
6969
return nil, 0, nil
7070
}
7171

72-
calls := make([]ToolCall, 0)
72+
calls := make([]openaiserverapi.ToolCall, 0)
7373
for i := range numberOfCalls {
7474
// Randomly choose which tools to call. We may call the same tool more than once.
75-
index := common.RandomInt(0, len(tools)-1)
76-
args, err := GenerateToolArguments(tools[index], config)
75+
index := RandomInt(0, len(tools)-1)
76+
args, err := generateToolArguments(tools[index], config)
7777
if err != nil {
7878
return nil, 0, err
7979
}
@@ -82,13 +82,13 @@ func CreateToolCalls(tools []Tool, toolChoice string, config *common.Configurati
8282
return nil, 0, err
8383
}
8484

85-
call := ToolCall{
86-
Function: FunctionCall{
85+
call := openaiserverapi.ToolCall{
86+
Function: openaiserverapi.FunctionCall{
8787
Arguments: string(argsJson),
88-
TokenizedArguments: common.Tokenize(string(argsJson)),
88+
TokenizedArguments: Tokenize(string(argsJson)),
8989
Name: &tools[index].Function.Name,
9090
},
91-
ID: "chatcmpl-tool-" + common.RandomNumericString(10),
91+
ID: "chatcmpl-tool-" + RandomNumericString(10),
9292
Type: "function",
9393
Index: i,
9494
}
@@ -98,7 +98,7 @@ func CreateToolCalls(tools []Tool, toolChoice string, config *common.Configurati
9898
return calls, CountTokensForToolCalls(calls), nil
9999
}
100100

101-
func GetRequiredAsMap(property map[string]any) map[string]struct{} {
101+
func getRequiredAsMap(property map[string]any) map[string]struct{} {
102102
required := make(map[string]struct{})
103103
requiredParams, ok := property["required"]
104104
if ok {
@@ -111,18 +111,18 @@ func GetRequiredAsMap(property map[string]any) map[string]struct{} {
111111
return required
112112
}
113113

114-
func GenerateToolArguments(tool Tool, config *common.Configuration) (map[string]any, error) {
114+
func generateToolArguments(tool openaiserverapi.Tool, config *Configuration) (map[string]any, error) {
115115
arguments := make(map[string]any)
116116
properties, _ := tool.Function.Parameters["properties"].(map[string]any)
117117

118-
required := GetRequiredAsMap(tool.Function.Parameters)
118+
required := getRequiredAsMap(tool.Function.Parameters)
119119

120120
for param, property := range properties {
121121
_, paramIsRequired := required[param]
122-
if !paramIsRequired && !common.RandomBool(config.ToolCallNotRequiredParamProbability) {
122+
if !paramIsRequired && !RandomBool(config.ToolCallNotRequiredParamProbability) {
123123
continue
124124
}
125-
arg, err := CreateArgument(property, config)
125+
arg, err := createArgument(property, config)
126126
if err != nil {
127127
return nil, err
128128
}
@@ -132,7 +132,7 @@ func GenerateToolArguments(tool Tool, config *common.Configuration) (map[string]
132132
return arguments, nil
133133
}
134134

135-
func CreateArgument(property any, config *common.Configuration) (any, error) {
135+
func createArgument(property any, config *Configuration) (any, error) {
136136
propertyMap, _ := property.(map[string]any)
137137
paramType := propertyMap["type"]
138138

@@ -141,20 +141,20 @@ func CreateArgument(property any, config *common.Configuration) (any, error) {
141141
if ok {
142142
enumArray, ok := enum.([]any)
143143
if ok && len(enumArray) > 0 {
144-
index := common.RandomInt(0, len(enumArray)-1)
144+
index := RandomInt(0, len(enumArray)-1)
145145
return enumArray[index], nil
146146
}
147147
}
148148

149149
switch paramType {
150150
case "string":
151-
return GetStringArgument(), nil
151+
return getStringArgument(), nil
152152
case "integer":
153-
return common.RandomInt(config.MinToolCallIntegerParam, config.MaxToolCallIntegerParam), nil
153+
return RandomInt(config.MinToolCallIntegerParam, config.MaxToolCallIntegerParam), nil
154154
case "number":
155-
return common.RandomFloat(config.MinToolCallNumberParam, config.MaxToolCallNumberParam), nil
155+
return RandomFloat(config.MinToolCallNumberParam, config.MaxToolCallNumberParam), nil
156156
case "boolean":
157-
return common.FlipCoin(), nil
157+
return FlipCoin(), nil
158158
case "array":
159159
items := propertyMap["items"]
160160
itemsMap := items.(map[string]any)
@@ -169,26 +169,26 @@ func CreateArgument(property any, config *common.Configuration) (any, error) {
169169
if minItems > maxItems {
170170
return nil, fmt.Errorf("minItems (%d) is greater than maxItems(%d)", minItems, maxItems)
171171
}
172-
numberOfElements := common.RandomInt(minItems, maxItems)
172+
numberOfElements := RandomInt(minItems, maxItems)
173173
array := make([]any, numberOfElements)
174174
for i := range numberOfElements {
175-
elem, err := CreateArgument(itemsMap, config)
175+
elem, err := createArgument(itemsMap, config)
176176
if err != nil {
177177
return nil, err
178178
}
179179
array[i] = elem
180180
}
181181
return array, nil
182182
case "object":
183-
required := GetRequiredAsMap(propertyMap)
183+
required := getRequiredAsMap(propertyMap)
184184
objectProperties := propertyMap["properties"].(map[string]any)
185185
object := make(map[string]interface{})
186186
for fieldName, fieldProperties := range objectProperties {
187187
_, fieldIsRequired := required[fieldName]
188-
if !fieldIsRequired && !common.RandomBool(config.ObjectToolCallNotRequiredParamProbability) {
188+
if !fieldIsRequired && !RandomBool(config.ObjectToolCallNotRequiredParamProbability) {
189189
continue
190190
}
191-
fieldValue, err := CreateArgument(fieldProperties, config)
191+
fieldValue, err := createArgument(fieldProperties, config)
192192
if err != nil {
193193
return nil, err
194194
}
@@ -200,24 +200,24 @@ func CreateArgument(property any, config *common.Configuration) (any, error) {
200200
}
201201
}
202202

203-
func GetStringArgument() string {
204-
index := common.RandomInt(0, len(fakeStringArguments)-1)
203+
func getStringArgument() string {
204+
index := RandomInt(0, len(fakeStringArguments)-1)
205205
return fakeStringArguments[index]
206206
}
207207

208-
type Validator struct {
208+
type ToolsValidator struct {
209209
schema *jsonschema.Schema
210210
}
211211

212-
func CreateValidator() (*Validator, error) {
212+
func CreateToolsValidator() (*ToolsValidator, error) {
213213
sch, err := jsonschema.CompileString("schema.json", schema)
214214
if err != nil {
215215
return nil, err
216216
}
217-
return &Validator{schema: sch}, nil
217+
return &ToolsValidator{schema: sch}, nil
218218
}
219219

220-
func (v *Validator) ValidateTool(tool []byte) error {
220+
func (v *ToolsValidator) ValidateTool(tool []byte) error {
221221
var value interface{}
222222
if err := json.Unmarshal(tool, &value); err != nil {
223223
return err

pkg/dataset/custom_dataset.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ func (d *CustomDataset) GetTokens(req openaiserverapi.CompletionRequest, mode st
439439
if mode == common.ModeEcho {
440440
return d.echo(req)
441441
}
442-
nTokensToGen, finishReason := howManyTokensToGen(d.extractMaxTokens(req), req.GetIgnoreEOS())
442+
nTokensToGen, finishReason := howManyTokensToGen(req.ExtractMaxTokens(), req.GetIgnoreEOS())
443443
tokens, err := d.GenerateTokens(req, nTokensToGen, finishReason)
444444
return tokens, finishReason, err
445445
}

pkg/dataset/dataset.go

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ package dataset
1818

1919
import (
2020
"context"
21-
"errors"
2221
"math"
2322
"math/rand"
2423

@@ -291,12 +290,7 @@ func (d *BaseDataset) Close() error {
291290
}
292291

293292
func (d *BaseDataset) echo(req openaiserverapi.CompletionRequest) ([]string, string, error) {
294-
nMaxTokens := d.extractMaxTokens(req)
295-
prompt, err := d.extractPrompt(req)
296-
if err != nil {
297-
return nil, "", err
298-
}
299-
tokens, finishReason := EchoResponseTokens(nMaxTokens, prompt)
293+
tokens, finishReason := EchoResponseTokens(req.ExtractMaxTokens(), req.ExtractPrompt())
300294
return tokens, finishReason, nil
301295
}
302296

@@ -305,30 +299,6 @@ func (d *BaseDataset) GetTokens(req openaiserverapi.CompletionRequest, mode stri
305299
if mode == common.ModeEcho {
306300
return d.echo(req)
307301
}
308-
nTokensToGen, finishReason := howManyTokensToGen(d.extractMaxTokens(req), req.GetIgnoreEOS())
302+
nTokensToGen, finishReason := howManyTokensToGen(req.ExtractMaxTokens(), req.GetIgnoreEOS())
309303
return GenPresetRandomTokens(nTokensToGen), finishReason, nil
310304
}
311-
312-
// extractMaxTokens extracts the max tokens from the request
313-
// for chat completion - max_completion_tokens field is used
314-
// for text completion - max_tokens field is used
315-
func (d *BaseDataset) extractMaxTokens(req openaiserverapi.CompletionRequest) *int64 {
316-
if chatReq, ok := req.(*openaiserverapi.ChatCompletionRequest); ok {
317-
return chatReq.GetMaxCompletionTokens()
318-
} else if textReq, ok := req.(*openaiserverapi.TextCompletionRequest); ok {
319-
return textReq.MaxTokens
320-
}
321-
return nil
322-
}
323-
324-
// extractPrompt extracts the prompt from the request
325-
// for chat completion - the last user message is used as the prompt
326-
// for text completion - the prompt field is used
327-
func (d *BaseDataset) extractPrompt(req openaiserverapi.CompletionRequest) (string, error) {
328-
if chatReq, ok := req.(*openaiserverapi.ChatCompletionRequest); ok {
329-
return chatReq.GetLastUserMsg(), nil
330-
} else if textReq, ok := req.(*openaiserverapi.TextCompletionRequest); ok {
331-
return textReq.GetPrompt(), nil
332-
}
333-
return "", errors.New("unknown request type")
334-
}

pkg/dataset/dataset_test.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,9 @@ var _ = Describe("Dataset", Ordered, func() {
9292
func(maxCompletionTokens int) {
9393
n := int64(maxCompletionTokens)
9494
req := &openaiserverapi.ChatCompletionRequest{
95-
BaseCompletionRequest: openaiserverapi.BaseCompletionRequest{
96-
IgnoreEOS: true,
97-
},
9895
MaxTokens: &n,
9996
}
97+
req.SetIgnoreEOS(true)
10098
tokens, finishReason, err := dataset.GetTokens(req, common.ModeRandom)
10199
Expect(err).ShouldNot(HaveOccurred())
102100
nGenTokens := int64(len(tokens))

pkg/llm-d-inference-sim/helpers.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ package llmdinferencesim
2020
import (
2121
"encoding/json"
2222
"fmt"
23+
24+
"github.com/llm-d/llm-d-inference-sim/pkg/common"
25+
openaiserverapi "github.com/llm-d/llm-d-inference-sim/pkg/openai-server-api"
2326
)
2427

2528
// isValidModel checks if the given model is the base model or one of "loaded" LoRAs
@@ -92,3 +95,7 @@ func (s *VllmSimulator) showConfig(dp bool) error {
9295
s.logger.Info("Configuration:", "", string(cfgJSON))
9396
return nil
9497
}
98+
99+
func (s *VllmSimulator) getNumberOfPromptTokens(req openaiserverapi.CompletionRequest) int {
100+
return len(common.Tokenize(req.GetPrompt()))
101+
}

pkg/llm-d-inference-sim/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ func (s *VllmSimulator) validateRequest(req openaiserverapi.CompletionRequest) (
233233
}
234234

235235
// Validate context window constraints
236-
promptTokens := req.GetNumberOfPromptTokens()
236+
promptTokens := s.getNumberOfPromptTokens(req)
237237
completionTokens := req.GetMaxCompletionTokens()
238238
isValid, actualCompletionTokens, totalTokens := common.ValidateContextWindow(promptTokens, completionTokens, s.config.MaxModelLen)
239239
if !isValid {

pkg/llm-d-inference-sim/simulator.go

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ type VllmSimulator struct {
143143
// loraAdaptors contains list of LoRA available adaptors
144144
loraAdaptors sync.Map
145145
// schema validator for tools parameters
146-
toolsValidator *openaiserverapi.Validator
146+
toolsValidator *common.ToolsValidator
147147
// kv cache functionality
148148
kvcacheHelper *kvcache.KVCacheHelper
149149
// namespace where simulator is running
@@ -175,7 +175,7 @@ type VllmSimulator struct {
175175

176176
// New creates a new VllmSimulator instance with the given logger
177177
func New(logger logr.Logger) (*VllmSimulator, error) {
178-
toolsValidator, err := openaiserverapi.CreateValidator()
178+
toolsValidator, err := common.CreateToolsValidator()
179179
if err != nil {
180180
return nil, fmt.Errorf("failed to create tools validator: %s", err)
181181
}
@@ -521,12 +521,8 @@ func (s *VllmSimulator) responseSentCallback(model string, isChatCompletion bool
521521
// from --served-model-name (for a base-model request) or the LoRA adapter name (for a LoRA request).
522522
func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respTokens []string, toolCalls []openaiserverapi.ToolCall,
523523
finishReason *string, usageData *openaiserverapi.Usage, modelName string, doRemoteDecode bool) openaiserverapi.CompletionResponse {
524-
baseResp := openaiserverapi.BaseCompletionResponse{
525-
ID: chatComplIDPrefix + common.GenerateUUIDString(),
526-
Created: time.Now().Unix(),
527-
Model: modelName,
528-
Usage: usageData,
529-
}
524+
baseResp := openaiserverapi.CreateBaseCompletionResponse(chatComplIDPrefix+common.GenerateUUIDString(),
525+
time.Now().Unix(), modelName, usageData)
530526

531527
if doRemoteDecode {
532528
// add special fields related to the prefill pod special behavior
@@ -539,7 +535,7 @@ func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respToke
539535
baseResp.RemotePort = 1234
540536
}
541537

542-
baseChoice := openaiserverapi.BaseResponseChoice{Index: 0, FinishReason: finishReason}
538+
baseChoice := openaiserverapi.CreateBaseResponseChoice(0, finishReason)
543539

544540
respText := strings.Join(respTokens, "")
545541
if isChatCompletion {
@@ -551,17 +547,13 @@ func (s *VllmSimulator) createCompletionResponse(isChatCompletion bool, respToke
551547
} else {
552548
message.Content = openaiserverapi.Content{Raw: respText}
553549
}
554-
return &openaiserverapi.ChatCompletionResponse{
555-
BaseCompletionResponse: baseResp,
556-
Choices: []openaiserverapi.ChatRespChoice{{Message: message, BaseResponseChoice: baseChoice}},
557-
}
550+
return openaiserverapi.CreateChatCompletionResponse(baseResp,
551+
[]openaiserverapi.ChatRespChoice{openaiserverapi.CreateChatRespChoice(baseChoice, message)})
558552
}
559553

560554
baseResp.Object = textCompletionObject
561-
return &openaiserverapi.TextCompletionResponse{
562-
BaseCompletionResponse: baseResp,
563-
Choices: []openaiserverapi.TextRespChoice{{BaseResponseChoice: baseChoice, Text: respText}},
564-
}
555+
return openaiserverapi.CreateTextCompletionResponse(baseResp,
556+
[]openaiserverapi.TextRespChoice{openaiserverapi.CreateTextRespChoice(baseChoice, respText)})
565557
}
566558

567559
// sendResponse sends response for completion API, supports both completions (text and chat)

0 commit comments

Comments
 (0)