diff --git a/openai.go b/openai.go index dc3abc8..d1cd543 100644 --- a/openai.go +++ b/openai.go @@ -15,13 +15,18 @@ import ( type ChatCompletionNewParamsWrapper struct { openai.ChatCompletionNewParams `json:""` Stream bool `json:"stream,omitempty"` + MaxCompletionTokens *int `json:"max_completion_tokens,omitempty"` } func (c ChatCompletionNewParamsWrapper) MarshalJSON() ([]byte, error) { type shadow ChatCompletionNewParamsWrapper - return param.MarshalWithExtras(c, (*shadow)(&c), map[string]any{ + extras := map[string]any{ "stream": c.Stream, - }) + } + if c.MaxCompletionTokens != nil { + extras["max_completion_tokens"] = *c.MaxCompletionTokens + } + return param.MarshalWithExtras(c, (*shadow)(&c), extras) } func (c *ChatCompletionNewParamsWrapper) UnmarshalJSON(raw []byte) error { @@ -43,6 +48,33 @@ func (c *ChatCompletionNewParamsWrapper) UnmarshalJSON(raw []byte) error { c.ChatCompletionNewParams.StreamOptions = openai.ChatCompletionStreamOptionsParam{} } + // Extract max_completion_tokens if present and positive + // OpenAI API requires positive integers for token limits + var data map[string]any + if err := json.Unmarshal(raw, &data); err == nil { + if val, exists := data["max_completion_tokens"]; exists { + // Field is explicitly set, convert to int + var tokens int + switch v := val.(type) { + case float64: + tokens = int(v) + case int: + tokens = v + case int64: + tokens = int(v) + default: + // Invalid type, skip + return nil + } + // Only set if positive (0 and negative values are invalid) + if tokens > 0 { + c.MaxCompletionTokens = &tokens + // Set it in the underlying params as well + c.ChatCompletionNewParams.MaxCompletionTokens = openai.Int(int64(tokens)) + } + } + } + return nil } diff --git a/openai_test.go b/openai_test.go index 35e52f2..b4a954a 100644 --- a/openai_test.go +++ b/openai_test.go @@ -1,6 +1,7 @@ package aibridge_test import ( + "encoding/json" "testing" "github.com/coder/aibridge" @@ -131,3 +132,88 @@ func TestOpenAILastUserPrompt(t *testing.T) { }) } } + +func TestMaxCompletionTokens(t *testing.T) { + t.Parallel() + + t.Run("unmarshal max_completion_tokens from JSON", func(t *testing.T) { + jsonStr := `{ + "model": "gpt-4o", + "messages": [{"role": "user", "content": "Hello"}], + "max_completion_tokens": 1024 + }` + + var wrapper aibridge.ChatCompletionNewParamsWrapper + err := json.Unmarshal([]byte(jsonStr), &wrapper) + require.NoError(t, err) + require.NotNil(t, wrapper.MaxCompletionTokens) + require.Equal(t, 1024, *wrapper.MaxCompletionTokens) + }) + + t.Run("unmarshal max_completion_tokens with zero value ignored", func(t *testing.T) { + jsonStr := `{ + "model": "gpt-4o", + "messages": [{"role": "user", "content": "Hello"}], + "max_completion_tokens": 0 + }` + + var wrapper aibridge.ChatCompletionNewParamsWrapper + err := json.Unmarshal([]byte(jsonStr), &wrapper) + require.NoError(t, err) + require.Nil(t, wrapper.MaxCompletionTokens, "max_completion_tokens should not be set when 0 (invalid value)") + }) + + t.Run("unmarshal max_completion_tokens with negative value ignored", func(t *testing.T) { + jsonStr := `{ + "model": "gpt-4o", + "messages": [{"role": "user", "content": "Hello"}], + "max_completion_tokens": -100 + }` + + var wrapper aibridge.ChatCompletionNewParamsWrapper + err := json.Unmarshal([]byte(jsonStr), &wrapper) + require.NoError(t, err) + require.Nil(t, wrapper.MaxCompletionTokens, "max_completion_tokens should not be set when negative (invalid value)") + }) + + t.Run("marshal max_completion_tokens to JSON", func(t *testing.T) { + maxTokens := 2048 + wrapper := aibridge.ChatCompletionNewParamsWrapper{ + ChatCompletionNewParams: openai.ChatCompletionNewParams{ + Model: openai.ChatModelGPT4o, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("Hello"), + }, + }, + MaxCompletionTokens: &maxTokens, + } + + jsonBytes, err := json.Marshal(wrapper) + require.NoError(t, err) + + var result map[string]interface{} + err = json.Unmarshal(jsonBytes, &result) + require.NoError(t, err) + require.Equal(t, float64(2048), result["max_completion_tokens"]) + }) + + t.Run("max_completion_tokens not set when nil", func(t *testing.T) { + wrapper := aibridge.ChatCompletionNewParamsWrapper{ + ChatCompletionNewParams: openai.ChatCompletionNewParams{ + Model: openai.ChatModelGPT4o, + Messages: []openai.ChatCompletionMessageParamUnion{ + openai.UserMessage("Hello"), + }, + }, + } + + jsonBytes, err := json.Marshal(wrapper) + require.NoError(t, err) + + var result map[string]interface{} + err = json.Unmarshal(jsonBytes, &result) + require.NoError(t, err) + _, exists := result["max_completion_tokens"] + require.False(t, exists, "max_completion_tokens should not be present when nil") + }) +}