Skip to content

Commit 14690bd

Browse files
committed
perf: only marshal once
Signed-off-by: Danny Kopping <[email protected]>
1 parent fab3b9d commit 14690bd

File tree

2 files changed

+81
-129
lines changed

2 files changed

+81
-129
lines changed

intercept/messages/paramswrap.go

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66

77
"github.com/anthropics/anthropic-sdk-go"
88
"github.com/anthropics/anthropic-sdk-go/packages/param"
9-
"github.com/coder/aibridge/utils"
109
)
1110

1211
// MessageNewParamsWrapper exists because the "stream" param is not included in anthropic.MessageNewParams.
@@ -23,18 +22,30 @@ func (b MessageNewParamsWrapper) MarshalJSON() ([]byte, error) {
2322
}
2423

2524
func (b *MessageNewParamsWrapper) UnmarshalJSON(raw []byte) error {
26-
convertedRaw, err := convertStringContentToArray(raw)
27-
if err != nil {
25+
// Parse JSON once and extract both stream field and do content conversion
26+
// to avoid double-parsing the same payload.
27+
var modifiedJSON map[string]any
28+
if err := json.Unmarshal(raw, &modifiedJSON); err != nil {
2829
return err
2930
}
3031

31-
err = b.MessageNewParams.UnmarshalJSON(convertedRaw)
32+
// Extract stream field from already-parsed map
33+
if stream, ok := modifiedJSON["stream"].(bool); ok {
34+
b.Stream = stream
35+
}
36+
37+
// Convert string content to array format if needed
38+
if _, hasMessages := modifiedJSON["messages"]; hasMessages {
39+
convertStringContentRecursive(modifiedJSON)
40+
}
41+
42+
// Marshal back for SDK parsing
43+
convertedRaw, err := json.Marshal(modifiedJSON)
3244
if err != nil {
3345
return err
3446
}
3547

36-
b.Stream = utils.ExtractJSONField[bool](raw, "stream")
37-
return nil
48+
return b.MessageNewParams.UnmarshalJSON(convertedRaw)
3849
}
3950

4051
func (b *MessageNewParamsWrapper) lastUserPrompt() (*string, error) {
@@ -69,31 +80,11 @@ func (b *MessageNewParamsWrapper) lastUserPrompt() (*string, error) {
6980
return nil, nil
7081
}
7182

72-
// convertStringContentToArray converts string content to array format for Anthropic messages.
73-
// https://docs.anthropic.com/en/api/messages#body-messages
74-
//
75-
// Each input message content may be either a single string or an array of content blocks, where each block has a
76-
// specific type. Using a string for content is shorthand for an array of one content block of type "text".
77-
func convertStringContentToArray(raw []byte) ([]byte, error) {
78-
var modifiedJSON map[string]any
79-
if err := json.Unmarshal(raw, &modifiedJSON); err != nil {
80-
return raw, err
81-
}
82-
83-
// Check if messages exist and need content conversion
84-
if _, hasMessages := modifiedJSON["messages"]; hasMessages {
85-
convertStringContentRecursive(modifiedJSON)
86-
87-
// Marshal back to JSON
88-
return json.Marshal(modifiedJSON)
89-
}
90-
91-
return raw, nil
92-
}
93-
9483
// convertStringContentRecursive recursively scans JSON data and converts string "content" fields
95-
// to proper text block arrays where needed for Anthropic SDK compatibility
96-
func convertStringContentRecursive(data any) {
84+
// to proper text block arrays where needed for Anthropic SDK compatibility.
85+
// Returns true if any modifications were made.
86+
func convertStringContentRecursive(data any) bool {
87+
modified := false
9788
switch v := data.(type) {
9889
case map[string]any:
9990
// Check if this object has a "content" field with string value
@@ -107,21 +98,27 @@ func convertStringContentRecursive(data any) {
10798
"text": contentStr,
10899
},
109100
}
101+
modified = true
110102
}
111103
}
112104
}
113105

114106
// Recursively process all values in the map
115107
for _, value := range v {
116-
convertStringContentRecursive(value)
108+
if convertStringContentRecursive(value) {
109+
modified = true
110+
}
117111
}
118112

119113
case []any:
120114
// Recursively process all items in the array
121115
for _, item := range v {
122-
convertStringContentRecursive(item)
116+
if convertStringContentRecursive(item) {
117+
modified = true
118+
}
123119
}
124120
}
121+
return modified
125122
}
126123

127124
// shouldConvertContentField determines if a "content" string field should be converted to text block array

intercept/messages/paramswrap_test.go

Lines changed: 52 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,128 +1,83 @@
11
package messages
22

33
import (
4-
"encoding/json"
54
"testing"
65

76
"github.com/anthropics/anthropic-sdk-go"
87
"github.com/stretchr/testify/require"
98
)
109

11-
func TestConvertStringContentToArray(t *testing.T) {
10+
func TestMessageNewParamsWrapperUnmarshalJSON(t *testing.T) {
1211
t.Parallel()
1312

1413
tests := []struct {
15-
name string
16-
input string
17-
expected string
14+
name string
15+
input string
16+
expectedStream bool
17+
checkContent func(t *testing.T, w *MessageNewParamsWrapper)
1818
}{
1919
{
20-
name: "empty json",
21-
input: `{}`,
22-
expected: `{}`,
23-
},
24-
{
25-
name: "message with string content",
26-
input: `{
27-
"messages": [
28-
{
29-
"role": "user",
30-
"content": "Hello world"
31-
}
32-
]
33-
}`,
34-
expected: `{"messages":[{"content":[{"text":"Hello world","type":"text"}],"role":"user"}]}`,
35-
},
36-
{
37-
name: "message with array content unchanged",
38-
input: `{
39-
"messages": [
40-
{
41-
"role": "user",
42-
"content": [{"type": "text", "text": "Hello"}]
43-
}
44-
]
45-
}`,
46-
expected: `{"messages":[{"content":[{"text":"Hello","type":"text"}],"role":"user"}]}`,
20+
name: "message with string content converts to array",
21+
input: `{"model":"claude-3","max_tokens":1000,"messages":[{"role":"user","content":"Hello world"}]}`,
22+
expectedStream: false,
23+
checkContent: func(t *testing.T, w *MessageNewParamsWrapper) {
24+
require.Len(t, w.Messages, 1)
25+
require.Equal(t, anthropic.MessageParamRoleUser, w.Messages[0].Role)
26+
text := w.Messages[0].Content[0].GetText()
27+
require.NotNil(t, text)
28+
require.Equal(t, "Hello world", *text)
29+
},
4730
},
4831
{
49-
name: "multiple messages with mixed content",
50-
input: `{
51-
"messages": [
52-
{
53-
"role": "user",
54-
"content": "First message"
55-
},
56-
{
57-
"role": "assistant",
58-
"content": [{"type": "text", "text": "Response"}]
59-
},
60-
{
61-
"role": "user",
62-
"content": "Second message"
63-
}
64-
]
65-
}`,
66-
expected: `{"messages":[{"content":[{"text":"First message","type":"text"}],"role":"user"},{"content":[{"text":"Response","type":"text"}],"role":"assistant"},{"content":[{"text":"Second message","type":"text"}],"role":"user"}]}`,
32+
name: "stream field extracted",
33+
input: `{"model":"claude-3","max_tokens":1000,"stream":true,"messages":[{"role":"user","content":"Hi"}]}`,
34+
expectedStream: true,
35+
checkContent: func(t *testing.T, w *MessageNewParamsWrapper) {
36+
require.Len(t, w.Messages, 1)
37+
},
6738
},
6839
{
69-
name: "tool_result with string content",
70-
input: `{
71-
"messages": [
72-
{
73-
"role": "user",
74-
"content": [
75-
{
76-
"type": "tool_result",
77-
"tool_use_id": "123",
78-
"content": "Tool output"
79-
}
80-
]
81-
}
82-
]
83-
}`,
84-
expected: `{"messages":[{"content":[{"content":[{"text":"Tool output","type":"text"}],"tool_use_id":"123","type":"tool_result"}],"role":"user"}]}`,
40+
name: "stream false",
41+
input: `{"model":"claude-3","max_tokens":1000,"stream":false,"messages":[{"role":"user","content":"Hi"}]}`,
42+
expectedStream: false,
43+
checkContent: nil,
8544
},
8645
{
87-
name: "mcp_tool_result with string content unchanged",
88-
input: `{
89-
"messages": [
90-
{
91-
"role": "user",
92-
"content": [
93-
{
94-
"type": "mcp_tool_result",
95-
"tool_use_id": "456",
96-
"content": "MCP output"
97-
}
98-
]
99-
}
100-
]
101-
}`,
102-
expected: `{"messages":[{"content":[{"content":"MCP output","tool_use_id":"456","type":"mcp_tool_result"}],"role":"user"}]}`,
46+
name: "array content unchanged",
47+
input: `{"model":"claude-3","max_tokens":1000,"messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}`,
48+
expectedStream: false,
49+
checkContent: func(t *testing.T, w *MessageNewParamsWrapper) {
50+
require.Len(t, w.Messages, 1)
51+
text := w.Messages[0].Content[0].GetText()
52+
require.NotNil(t, text)
53+
require.Equal(t, "Hello", *text)
54+
},
10355
},
10456
{
105-
name: "no messages field",
106-
input: `{
107-
"model": "claude-3",
108-
"max_tokens": 1000
109-
}`,
110-
expected: `{"max_tokens":1000,"model":"claude-3"}`,
57+
name: "multiple messages with mixed content",
58+
input: `{"model":"claude-3","max_tokens":1000,"messages":[{"role":"user","content":"First"},{"role":"assistant","content":[{"type":"text","text":"Response"}]},{"role":"user","content":"Second"}]}`,
59+
expectedStream: false,
60+
checkContent: func(t *testing.T, w *MessageNewParamsWrapper) {
61+
require.Len(t, w.Messages, 3)
62+
text0 := w.Messages[0].Content[0].GetText()
63+
require.NotNil(t, text0)
64+
require.Equal(t, "First", *text0)
65+
text2 := w.Messages[2].Content[0].GetText()
66+
require.NotNil(t, text2)
67+
require.Equal(t, "Second", *text2)
68+
},
11169
},
11270
}
11371

11472
for _, tt := range tests {
11573
t.Run(tt.name, func(t *testing.T) {
116-
result, err := convertStringContentToArray([]byte(tt.input))
117-
require.NoError(t, err)
118-
119-
var resultJSON, expectedJSON any
120-
err = json.Unmarshal(result, &resultJSON)
121-
require.NoError(t, err)
122-
err = json.Unmarshal([]byte(tt.expected), &expectedJSON)
74+
var wrapper MessageNewParamsWrapper
75+
err := wrapper.UnmarshalJSON([]byte(tt.input))
12376
require.NoError(t, err)
124-
125-
require.Equal(t, expectedJSON, resultJSON)
77+
require.Equal(t, tt.expectedStream, wrapper.Stream)
78+
if tt.checkContent != nil {
79+
tt.checkContent(t, &wrapper)
80+
}
12681
})
12782
}
12883
}

0 commit comments

Comments
 (0)