|
1 | 1 | package messages |
2 | 2 |
|
3 | 3 | import ( |
4 | | - "encoding/json" |
5 | 4 | "testing" |
6 | 5 |
|
7 | 6 | "github.com/anthropics/anthropic-sdk-go" |
8 | 7 | "github.com/stretchr/testify/require" |
9 | 8 | ) |
10 | 9 |
|
11 | | -func TestConvertStringContentToArray(t *testing.T) { |
| 10 | +func TestMessageNewParamsWrapperUnmarshalJSON(t *testing.T) { |
12 | 11 | t.Parallel() |
13 | 12 |
|
14 | 13 | tests := []struct { |
15 | | - name string |
16 | | - input string |
17 | | - expected string |
| 14 | + name string |
| 15 | + input string |
| 16 | + expectedStream bool |
| 17 | + checkContent func(t *testing.T, w *MessageNewParamsWrapper) |
18 | 18 | }{ |
19 | 19 | { |
20 | | - name: "empty json", |
21 | | - input: `{}`, |
22 | | - expected: `{}`, |
23 | | - }, |
24 | | - { |
25 | | - name: "message with string content", |
26 | | - input: `{ |
27 | | - "messages": [ |
28 | | - { |
29 | | - "role": "user", |
30 | | - "content": "Hello world" |
31 | | - } |
32 | | - ] |
33 | | - }`, |
34 | | - expected: `{"messages":[{"content":[{"text":"Hello world","type":"text"}],"role":"user"}]}`, |
35 | | - }, |
36 | | - { |
37 | | - name: "message with array content unchanged", |
38 | | - input: `{ |
39 | | - "messages": [ |
40 | | - { |
41 | | - "role": "user", |
42 | | - "content": [{"type": "text", "text": "Hello"}] |
43 | | - } |
44 | | - ] |
45 | | - }`, |
46 | | - expected: `{"messages":[{"content":[{"text":"Hello","type":"text"}],"role":"user"}]}`, |
| 20 | + name: "message with string content converts to array", |
| 21 | + input: `{"model":"claude-3","max_tokens":1000,"messages":[{"role":"user","content":"Hello world"}]}`, |
| 22 | + expectedStream: false, |
| 23 | + checkContent: func(t *testing.T, w *MessageNewParamsWrapper) { |
| 24 | + require.Len(t, w.Messages, 1) |
| 25 | + require.Equal(t, anthropic.MessageParamRoleUser, w.Messages[0].Role) |
| 26 | + text := w.Messages[0].Content[0].GetText() |
| 27 | + require.NotNil(t, text) |
| 28 | + require.Equal(t, "Hello world", *text) |
| 29 | + }, |
47 | 30 | }, |
48 | 31 | { |
49 | | - name: "multiple messages with mixed content", |
50 | | - input: `{ |
51 | | - "messages": [ |
52 | | - { |
53 | | - "role": "user", |
54 | | - "content": "First message" |
55 | | - }, |
56 | | - { |
57 | | - "role": "assistant", |
58 | | - "content": [{"type": "text", "text": "Response"}] |
59 | | - }, |
60 | | - { |
61 | | - "role": "user", |
62 | | - "content": "Second message" |
63 | | - } |
64 | | - ] |
65 | | - }`, |
66 | | - expected: `{"messages":[{"content":[{"text":"First message","type":"text"}],"role":"user"},{"content":[{"text":"Response","type":"text"}],"role":"assistant"},{"content":[{"text":"Second message","type":"text"}],"role":"user"}]}`, |
| 32 | + name: "stream field extracted", |
| 33 | + input: `{"model":"claude-3","max_tokens":1000,"stream":true,"messages":[{"role":"user","content":"Hi"}]}`, |
| 34 | + expectedStream: true, |
| 35 | + checkContent: func(t *testing.T, w *MessageNewParamsWrapper) { |
| 36 | + require.Len(t, w.Messages, 1) |
| 37 | + }, |
67 | 38 | }, |
68 | 39 | { |
69 | | - name: "tool_result with string content", |
70 | | - input: `{ |
71 | | - "messages": [ |
72 | | - { |
73 | | - "role": "user", |
74 | | - "content": [ |
75 | | - { |
76 | | - "type": "tool_result", |
77 | | - "tool_use_id": "123", |
78 | | - "content": "Tool output" |
79 | | - } |
80 | | - ] |
81 | | - } |
82 | | - ] |
83 | | - }`, |
84 | | - expected: `{"messages":[{"content":[{"content":[{"text":"Tool output","type":"text"}],"tool_use_id":"123","type":"tool_result"}],"role":"user"}]}`, |
| 40 | + name: "stream false", |
| 41 | + input: `{"model":"claude-3","max_tokens":1000,"stream":false,"messages":[{"role":"user","content":"Hi"}]}`, |
| 42 | + expectedStream: false, |
| 43 | + checkContent: nil, |
85 | 44 | }, |
86 | 45 | { |
87 | | - name: "mcp_tool_result with string content unchanged", |
88 | | - input: `{ |
89 | | - "messages": [ |
90 | | - { |
91 | | - "role": "user", |
92 | | - "content": [ |
93 | | - { |
94 | | - "type": "mcp_tool_result", |
95 | | - "tool_use_id": "456", |
96 | | - "content": "MCP output" |
97 | | - } |
98 | | - ] |
99 | | - } |
100 | | - ] |
101 | | - }`, |
102 | | - expected: `{"messages":[{"content":[{"content":"MCP output","tool_use_id":"456","type":"mcp_tool_result"}],"role":"user"}]}`, |
| 46 | + name: "array content unchanged", |
| 47 | + input: `{"model":"claude-3","max_tokens":1000,"messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}`, |
| 48 | + expectedStream: false, |
| 49 | + checkContent: func(t *testing.T, w *MessageNewParamsWrapper) { |
| 50 | + require.Len(t, w.Messages, 1) |
| 51 | + text := w.Messages[0].Content[0].GetText() |
| 52 | + require.NotNil(t, text) |
| 53 | + require.Equal(t, "Hello", *text) |
| 54 | + }, |
103 | 55 | }, |
104 | 56 | { |
105 | | - name: "no messages field", |
106 | | - input: `{ |
107 | | - "model": "claude-3", |
108 | | - "max_tokens": 1000 |
109 | | - }`, |
110 | | - expected: `{"max_tokens":1000,"model":"claude-3"}`, |
| 57 | + name: "multiple messages with mixed content", |
| 58 | + input: `{"model":"claude-3","max_tokens":1000,"messages":[{"role":"user","content":"First"},{"role":"assistant","content":[{"type":"text","text":"Response"}]},{"role":"user","content":"Second"}]}`, |
| 59 | + expectedStream: false, |
| 60 | + checkContent: func(t *testing.T, w *MessageNewParamsWrapper) { |
| 61 | + require.Len(t, w.Messages, 3) |
| 62 | + text0 := w.Messages[0].Content[0].GetText() |
| 63 | + require.NotNil(t, text0) |
| 64 | + require.Equal(t, "First", *text0) |
| 65 | + text2 := w.Messages[2].Content[0].GetText() |
| 66 | + require.NotNil(t, text2) |
| 67 | + require.Equal(t, "Second", *text2) |
| 68 | + }, |
111 | 69 | }, |
112 | 70 | } |
113 | 71 |
|
114 | 72 | for _, tt := range tests { |
115 | 73 | t.Run(tt.name, func(t *testing.T) { |
116 | | - result, err := convertStringContentToArray([]byte(tt.input)) |
117 | | - require.NoError(t, err) |
118 | | - |
119 | | - var resultJSON, expectedJSON any |
120 | | - err = json.Unmarshal(result, &resultJSON) |
121 | | - require.NoError(t, err) |
122 | | - err = json.Unmarshal([]byte(tt.expected), &expectedJSON) |
| 74 | + var wrapper MessageNewParamsWrapper |
| 75 | + err := wrapper.UnmarshalJSON([]byte(tt.input)) |
123 | 76 | require.NoError(t, err) |
124 | | - |
125 | | - require.Equal(t, expectedJSON, resultJSON) |
| 77 | + require.Equal(t, tt.expectedStream, wrapper.Stream) |
| 78 | + if tt.checkContent != nil { |
| 79 | + tt.checkContent(t, &wrapper) |
| 80 | + } |
126 | 81 | }) |
127 | 82 | } |
128 | 83 | } |
|
0 commit comments