Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,18 @@ import (
// ChatCompletionNewParamsWrapper wraps openai.ChatCompletionNewParams so that
// fields the upstream SDK type does not round-trip through JSON by itself
// (stream, max_completion_tokens) can be captured and re-emitted.
type ChatCompletionNewParamsWrapper struct {
	// Embedded SDK params; empty json tag keeps its fields at the top level.
	openai.ChatCompletionNewParams `json:""`
	// Stream mirrors the request's "stream" flag.
	Stream bool `json:"stream,omitempty"`
	// MaxCompletionTokens mirrors "max_completion_tokens"; nil means unset,
	// so marshaling can omit the key entirely rather than emit a zero.
	MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"`
}

func (c ChatCompletionNewParamsWrapper) MarshalJSON() ([]byte, error) {
type shadow ChatCompletionNewParamsWrapper
return param.MarshalWithExtras(c, (*shadow)(&c), map[string]any{
extras := map[string]any{
"stream": c.Stream,
})
}
if c.MaxCompletionTokens != nil {
extras["max_completion_tokens"] = *c.MaxCompletionTokens
}
return param.MarshalWithExtras(c, (*shadow)(&c), extras)
}

func (c *ChatCompletionNewParamsWrapper) UnmarshalJSON(raw []byte) error {
Expand All @@ -43,6 +48,14 @@ func (c *ChatCompletionNewParamsWrapper) UnmarshalJSON(raw []byte) error {
c.ChatCompletionNewParams.StreamOptions = openai.ChatCompletionStreamOptionsParam{}
}

// Extract max_completion_tokens if present
if maxCompletionTokens := utils.ExtractJSONField[float64](raw, "max_completion_tokens"); maxCompletionTokens > 0 {
tokens := int(maxCompletionTokens)
c.MaxCompletionTokens = &tokens
// Set it in the underlying params as well
c.ChatCompletionNewParams.MaxCompletionTokens = openai.Int(int64(tokens))
}

return nil
}

Expand Down
60 changes: 60 additions & 0 deletions openai_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package aibridge_test

import (
"encoding/json"
"testing"

"github.com/coder/aibridge"
Expand Down Expand Up @@ -131,3 +132,62 @@ func TestOpenAILastUserPrompt(t *testing.T) {
})
}
}

// TestMaxCompletionTokens checks that max_completion_tokens round-trips
// through ChatCompletionNewParamsWrapper: it is decoded from incoming JSON,
// encoded back out when set, and omitted entirely when nil.
func TestMaxCompletionTokens(t *testing.T) {
	t.Parallel()

	t.Run("unmarshal max_completion_tokens from JSON", func(t *testing.T) {
		payload := `{
			"model": "gpt-4o",
			"messages": [{"role": "user", "content": "Hello"}],
			"max_completion_tokens": 1024
		}`

		var params aibridge.ChatCompletionNewParamsWrapper
		require.NoError(t, json.Unmarshal([]byte(payload), &params))

		require.NotNil(t, params.MaxCompletionTokens)
		require.Equal(t, 1024, *params.MaxCompletionTokens)
	})

	t.Run("marshal max_completion_tokens to JSON", func(t *testing.T) {
		limit := 2048
		params := aibridge.ChatCompletionNewParamsWrapper{
			ChatCompletionNewParams: openai.ChatCompletionNewParams{
				Model: openai.ChatModelGPT4o,
				Messages: []openai.ChatCompletionMessageParamUnion{
					openai.UserMessage("Hello"),
				},
			},
			MaxCompletionTokens: &limit,
		}

		data, err := json.Marshal(params)
		require.NoError(t, err)

		// JSON numbers decode as float64 in a generic map.
		var decoded map[string]any
		require.NoError(t, json.Unmarshal(data, &decoded))
		require.Equal(t, float64(2048), decoded["max_completion_tokens"])
	})

	t.Run("max_completion_tokens not set when nil", func(t *testing.T) {
		params := aibridge.ChatCompletionNewParamsWrapper{
			ChatCompletionNewParams: openai.ChatCompletionNewParams{
				Model: openai.ChatModelGPT4o,
				Messages: []openai.ChatCompletionMessageParamUnion{
					openai.UserMessage("Hello"),
				},
			},
		}

		data, err := json.Marshal(params)
		require.NoError(t, err)

		var decoded map[string]any
		require.NoError(t, json.Unmarshal(data, &decoded))

		// The key must be absent — not present with a zero/null value.
		_, ok := decoded["max_completion_tokens"]
		require.False(t, ok, "max_completion_tokens should not be present when nil")
	})
}