diff --git a/intercept/messages/paramswrap.go b/intercept/messages/paramswrap.go index af595f9..bd5175a 100644 --- a/intercept/messages/paramswrap.go +++ b/intercept/messages/paramswrap.go @@ -6,7 +6,6 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/packages/param" - "github.com/coder/aibridge/utils" ) // MessageNewParamsWrapper exists because the "stream" param is not included in anthropic.MessageNewParams. @@ -23,18 +22,30 @@ func (b MessageNewParamsWrapper) MarshalJSON() ([]byte, error) { } func (b *MessageNewParamsWrapper) UnmarshalJSON(raw []byte) error { - convertedRaw, err := convertStringContentToArray(raw) - if err != nil { + // Parse JSON once and extract both stream field and do content conversion + // to avoid double-parsing the same payload. + var modifiedJSON map[string]any + if err := json.Unmarshal(raw, &modifiedJSON); err != nil { return err } - err = b.MessageNewParams.UnmarshalJSON(convertedRaw) + // Extract stream field from already-parsed map + if stream, ok := modifiedJSON["stream"].(bool); ok { + b.Stream = stream + } + + // Convert string content to array format if needed + if _, hasMessages := modifiedJSON["messages"]; hasMessages { + convertStringContentRecursive(modifiedJSON) + } + + // Marshal back for SDK parsing + convertedRaw, err := json.Marshal(modifiedJSON) if err != nil { return err } - b.Stream = utils.ExtractJSONField[bool](raw, "stream") - return nil + return b.MessageNewParams.UnmarshalJSON(convertedRaw) } func (b *MessageNewParamsWrapper) lastUserPrompt() (*string, error) { @@ -69,31 +80,11 @@ func (b *MessageNewParamsWrapper) lastUserPrompt() (*string, error) { return nil, nil } -// convertStringContentToArray converts string content to array format for Anthropic messages. -// https://docs.anthropic.com/en/api/messages#body-messages -// -// Each input message content may be either a single string or an array of content blocks, where each block has a -// specific type. Using a string for content is shorthand for an array of one content block of type "text". -func convertStringContentToArray(raw []byte) ([]byte, error) { - var modifiedJSON map[string]any - if err := json.Unmarshal(raw, &modifiedJSON); err != nil { - return raw, err - } - - // Check if messages exist and need content conversion - if _, hasMessages := modifiedJSON["messages"]; hasMessages { - convertStringContentRecursive(modifiedJSON) - - // Marshal back to JSON - return json.Marshal(modifiedJSON) - } - - return raw, nil -} - // convertStringContentRecursive recursively scans JSON data and converts string "content" fields -// to proper text block arrays where needed for Anthropic SDK compatibility -func convertStringContentRecursive(data any) { +// to proper text block arrays where needed for Anthropic SDK compatibility. +// Returns true if any modifications were made. +func convertStringContentRecursive(data any) bool { + modified := false switch v := data.(type) { case map[string]any: // Check if this object has a "content" field with string value @@ -107,21 +98,27 @@ func convertStringContentRecursive(data any) { "text": contentStr, }, } + modified = true } } } // Recursively process all values in the map for _, value := range v { - convertStringContentRecursive(value) + if convertStringContentRecursive(value) { + modified = true + } } case []any: // Recursively process all items in the array for _, item := range v { - convertStringContentRecursive(item) + if convertStringContentRecursive(item) { + modified = true + } } } + return modified } // shouldConvertContentField determines if a "content" string field should be converted to text block array diff --git a/intercept/messages/paramswrap_test.go b/intercept/messages/paramswrap_test.go index 992855d..7f8793d 100644 --- a/intercept/messages/paramswrap_test.go +++ b/intercept/messages/paramswrap_test.go @@ -1,128 +1,83 @@ package messages import ( - "encoding/json" "testing" "github.com/anthropics/anthropic-sdk-go" "github.com/stretchr/testify/require" ) -func TestConvertStringContentToArray(t *testing.T) { +func TestMessageNewParamsWrapperUnmarshalJSON(t *testing.T) { t.Parallel() tests := []struct { - name string - input string - expected string + name string + input string + expectedStream bool + checkContent func(t *testing.T, w *MessageNewParamsWrapper) }{ { - name: "empty json", - input: `{}`, - expected: `{}`, - }, - { - name: "message with string content", - input: `{ - "messages": [ - { - "role": "user", - "content": "Hello world" - } - ] - }`, - expected: `{"messages":[{"content":[{"text":"Hello world","type":"text"}],"role":"user"}]}`, - }, - { - name: "message with array content unchanged", - input: `{ - "messages": [ - { - "role": "user", - "content": [{"type": "text", "text": "Hello"}] - } - ] - }`, - expected: `{"messages":[{"content":[{"text":"Hello","type":"text"}],"role":"user"}]}`, + name: "message with string content converts to array", + input: `{"model":"claude-3","max_tokens":1000,"messages":[{"role":"user","content":"Hello world"}]}`, + expectedStream: false, + checkContent: func(t *testing.T, w *MessageNewParamsWrapper) { + require.Len(t, w.Messages, 1) + require.Equal(t, anthropic.MessageParamRoleUser, w.Messages[0].Role) + text := w.Messages[0].Content[0].GetText() + require.NotNil(t, text) + require.Equal(t, "Hello world", *text) + }, }, { - name: "multiple messages with mixed content", - input: `{ - "messages": [ - { - "role": "user", - "content": "First message" - }, - { - "role": "assistant", - "content": [{"type": "text", "text": "Response"}] - }, - { - "role": "user", - "content": "Second message" - } - ] - }`, - expected: `{"messages":[{"content":[{"text":"First message","type":"text"}],"role":"user"},{"content":[{"text":"Response","type":"text"}],"role":"assistant"},{"content":[{"text":"Second message","type":"text"}],"role":"user"}]}`, + name: "stream field extracted", + input: `{"model":"claude-3","max_tokens":1000,"stream":true,"messages":[{"role":"user","content":"Hi"}]}`, + expectedStream: true, + checkContent: func(t *testing.T, w *MessageNewParamsWrapper) { + require.Len(t, w.Messages, 1) + }, }, { - name: "tool_result with string content", - input: `{ - "messages": [ - { - "role": "user", - "content": [ - { - "type": "tool_result", - "tool_use_id": "123", - "content": "Tool output" - } - ] - } - ] - }`, - expected: `{"messages":[{"content":[{"content":[{"text":"Tool output","type":"text"}],"tool_use_id":"123","type":"tool_result"}],"role":"user"}]}`, + name: "stream false", + input: `{"model":"claude-3","max_tokens":1000,"stream":false,"messages":[{"role":"user","content":"Hi"}]}`, + expectedStream: false, + checkContent: nil, }, { - name: "mcp_tool_result with string content unchanged", - input: `{ - "messages": [ - { - "role": "user", - "content": [ - { - "type": "mcp_tool_result", - "tool_use_id": "456", - "content": "MCP output" - } - ] - } - ] - }`, - expected: `{"messages":[{"content":[{"content":"MCP output","tool_use_id":"456","type":"mcp_tool_result"}],"role":"user"}]}`, + name: "array content unchanged", + input: `{"model":"claude-3","max_tokens":1000,"messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}`, + expectedStream: false, + checkContent: func(t *testing.T, w *MessageNewParamsWrapper) { + require.Len(t, w.Messages, 1) + text := w.Messages[0].Content[0].GetText() + require.NotNil(t, text) + require.Equal(t, "Hello", *text) + }, }, { - name: "no messages field", - input: `{ - "model": "claude-3", - "max_tokens": 1000 - }`, - expected: `{"max_tokens":1000,"model":"claude-3"}`, + name: "multiple messages with mixed content", + input: `{"model":"claude-3","max_tokens":1000,"messages":[{"role":"user","content":"First"},{"role":"assistant","content":[{"type":"text","text":"Response"}]},{"role":"user","content":"Second"}]}`, + expectedStream: false, + checkContent: func(t *testing.T, w *MessageNewParamsWrapper) { + require.Len(t, w.Messages, 3) + text0 := w.Messages[0].Content[0].GetText() + require.NotNil(t, text0) + require.Equal(t, "First", *text0) + text2 := w.Messages[2].Content[0].GetText() + require.NotNil(t, text2) + require.Equal(t, "Second", *text2) + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := convertStringContentToArray([]byte(tt.input)) - require.NoError(t, err) - - var resultJSON, expectedJSON any - err = json.Unmarshal(result, &resultJSON) - require.NoError(t, err) - err = json.Unmarshal([]byte(tt.expected), &expectedJSON) + var wrapper MessageNewParamsWrapper + err := wrapper.UnmarshalJSON([]byte(tt.input)) require.NoError(t, err) - - require.Equal(t, expectedJSON, resultJSON) + require.Equal(t, tt.expectedStream, wrapper.Stream) + if tt.checkContent != nil { + tt.checkContent(t, &wrapper) + } }) } }