Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 29 additions & 32 deletions intercept/messages/paramswrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (

"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/packages/param"
"github.com/coder/aibridge/utils"
)

// MessageNewParamsWrapper exists because the "stream" param is not included in anthropic.MessageNewParams.
Expand All @@ -23,18 +22,30 @@ func (b MessageNewParamsWrapper) MarshalJSON() ([]byte, error) {
}

func (b *MessageNewParamsWrapper) UnmarshalJSON(raw []byte) error {
convertedRaw, err := convertStringContentToArray(raw)
if err != nil {
// Parse JSON once and extract both stream field and do content conversion
// to avoid double-parsing the same payload.
var modifiedJSON map[string]any
if err := json.Unmarshal(raw, &modifiedJSON); err != nil {
return err
}

err = b.MessageNewParams.UnmarshalJSON(convertedRaw)
// Extract stream field from already-parsed map
if stream, ok := modifiedJSON["stream"].(bool); ok {
b.Stream = stream
}

// Convert string content to array format if needed
if _, hasMessages := modifiedJSON["messages"]; hasMessages {
convertStringContentRecursive(modifiedJSON)
}

// Marshal back for SDK parsing
convertedRaw, err := json.Marshal(modifiedJSON)
if err != nil {
return err
}

b.Stream = utils.ExtractJSONField[bool](raw, "stream")
return nil
return b.MessageNewParams.UnmarshalJSON(convertedRaw)
}

func (b *MessageNewParamsWrapper) lastUserPrompt() (*string, error) {
Expand Down Expand Up @@ -69,31 +80,11 @@ func (b *MessageNewParamsWrapper) lastUserPrompt() (*string, error) {
return nil, nil
}

// convertStringContentToArray converts string content to array format for Anthropic messages.
// https://docs.anthropic.com/en/api/messages#body-messages
//
// Each input message content may be either a single string or an array of content blocks, where each block has a
// specific type. Using a string for content is shorthand for an array of one content block of type "text".
func convertStringContentToArray(raw []byte) ([]byte, error) {
var modifiedJSON map[string]any
if err := json.Unmarshal(raw, &modifiedJSON); err != nil {
return raw, err
}

// Check if messages exist and need content conversion
if _, hasMessages := modifiedJSON["messages"]; hasMessages {
convertStringContentRecursive(modifiedJSON)

// Marshal back to JSON
return json.Marshal(modifiedJSON)
}

return raw, nil
}

// convertStringContentRecursive recursively scans JSON data and converts string "content" fields
// to proper text block arrays where needed for Anthropic SDK compatibility
func convertStringContentRecursive(data any) {
// to proper text block arrays where needed for Anthropic SDK compatibility.
// Returns true if any modifications were made.
func convertStringContentRecursive(data any) bool {
modified := false
switch v := data.(type) {
case map[string]any:
// Check if this object has a "content" field with string value
Expand All @@ -107,21 +98,27 @@ func convertStringContentRecursive(data any) {
"text": contentStr,
},
}
modified = true
}
}
}

// Recursively process all values in the map
for _, value := range v {
convertStringContentRecursive(value)
if convertStringContentRecursive(value) {
modified = true
}
}

case []any:
// Recursively process all items in the array
for _, item := range v {
convertStringContentRecursive(item)
if convertStringContentRecursive(item) {
modified = true
}
}
}
return modified
}

// shouldConvertContentField determines if a "content" string field should be converted to text block array
Expand Down
149 changes: 52 additions & 97 deletions intercept/messages/paramswrap_test.go
Original file line number Diff line number Diff line change
@@ -1,128 +1,83 @@
package messages

import (
"encoding/json"
"testing"

"github.com/anthropics/anthropic-sdk-go"
"github.com/stretchr/testify/require"
)

func TestConvertStringContentToArray(t *testing.T) {
func TestMessageNewParamsWrapperUnmarshalJSON(t *testing.T) {
t.Parallel()

tests := []struct {
name string
input string
expected string
name string
input string
expectedStream bool
checkContent func(t *testing.T, w *MessageNewParamsWrapper)
}{
{
name: "empty json",
input: `{}`,
expected: `{}`,
},
{
name: "message with string content",
input: `{
"messages": [
{
"role": "user",
"content": "Hello world"
}
]
}`,
expected: `{"messages":[{"content":[{"text":"Hello world","type":"text"}],"role":"user"}]}`,
},
{
name: "message with array content unchanged",
input: `{
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "Hello"}]
}
]
}`,
expected: `{"messages":[{"content":[{"text":"Hello","type":"text"}],"role":"user"}]}`,
name: "message with string content converts to array",
input: `{"model":"claude-3","max_tokens":1000,"messages":[{"role":"user","content":"Hello world"}]}`,
expectedStream: false,
checkContent: func(t *testing.T, w *MessageNewParamsWrapper) {
require.Len(t, w.Messages, 1)
require.Equal(t, anthropic.MessageParamRoleUser, w.Messages[0].Role)
text := w.Messages[0].Content[0].GetText()
require.NotNil(t, text)
require.Equal(t, "Hello world", *text)
},
},
{
name: "multiple messages with mixed content",
input: `{
"messages": [
{
"role": "user",
"content": "First message"
},
{
"role": "assistant",
"content": [{"type": "text", "text": "Response"}]
},
{
"role": "user",
"content": "Second message"
}
]
}`,
expected: `{"messages":[{"content":[{"text":"First message","type":"text"}],"role":"user"},{"content":[{"text":"Response","type":"text"}],"role":"assistant"},{"content":[{"text":"Second message","type":"text"}],"role":"user"}]}`,
name: "stream field extracted",
input: `{"model":"claude-3","max_tokens":1000,"stream":true,"messages":[{"role":"user","content":"Hi"}]}`,
expectedStream: true,
checkContent: func(t *testing.T, w *MessageNewParamsWrapper) {
require.Len(t, w.Messages, 1)
},
},
{
name: "tool_result with string content",
input: `{
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "123",
"content": "Tool output"
}
]
}
]
}`,
expected: `{"messages":[{"content":[{"content":[{"text":"Tool output","type":"text"}],"tool_use_id":"123","type":"tool_result"}],"role":"user"}]}`,
name: "stream false",
input: `{"model":"claude-3","max_tokens":1000,"stream":false,"messages":[{"role":"user","content":"Hi"}]}`,
expectedStream: false,
checkContent: nil,
},
{
name: "mcp_tool_result with string content unchanged",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this test case (and others related to tool_result, mcp_tool_result) no longer relevant? We still have logic in convertStringContentRecursive that relates to these types of content.

input: `{
"messages": [
{
"role": "user",
"content": [
{
"type": "mcp_tool_result",
"tool_use_id": "456",
"content": "MCP output"
}
]
}
]
}`,
expected: `{"messages":[{"content":[{"content":"MCP output","tool_use_id":"456","type":"mcp_tool_result"}],"role":"user"}]}`,
name: "array content unchanged",
input: `{"model":"claude-3","max_tokens":1000,"messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}`,
expectedStream: false,
checkContent: func(t *testing.T, w *MessageNewParamsWrapper) {
require.Len(t, w.Messages, 1)
text := w.Messages[0].Content[0].GetText()
require.NotNil(t, text)
require.Equal(t, "Hello", *text)
},
},
{
name: "no messages field",
input: `{
"model": "claude-3",
"max_tokens": 1000
}`,
expected: `{"max_tokens":1000,"model":"claude-3"}`,
name: "multiple messages with mixed content",
input: `{"model":"claude-3","max_tokens":1000,"messages":[{"role":"user","content":"First"},{"role":"assistant","content":[{"type":"text","text":"Response"}]},{"role":"user","content":"Second"}]}`,
expectedStream: false,
checkContent: func(t *testing.T, w *MessageNewParamsWrapper) {
require.Len(t, w.Messages, 3)
text0 := w.Messages[0].Content[0].GetText()
require.NotNil(t, text0)
require.Equal(t, "First", *text0)
text2 := w.Messages[2].Content[0].GetText()
require.NotNil(t, text2)
require.Equal(t, "Second", *text2)
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := convertStringContentToArray([]byte(tt.input))
require.NoError(t, err)

var resultJSON, expectedJSON any
err = json.Unmarshal(result, &resultJSON)
require.NoError(t, err)
err = json.Unmarshal([]byte(tt.expected), &expectedJSON)
var wrapper MessageNewParamsWrapper
err := wrapper.UnmarshalJSON([]byte(tt.input))
require.NoError(t, err)

require.Equal(t, expectedJSON, resultJSON)
require.Equal(t, tt.expectedStream, wrapper.Stream)
if tt.checkContent != nil {
tt.checkContent(t, &wrapper)
}
})
}
}
Expand Down