Skip to content

Commit 068c87e

Browse files
authored
chore: correct last user prompt detection (#10)
Signed-off-by: Danny Kopping <[email protected]>
1 parent 986a7ac commit 068c87e

File tree

5 files changed

+338
-29
lines changed

5 files changed

+338
-29
lines changed

anthropic.go

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package aibridge
33
import (
44
"encoding/json"
55
"errors"
6-
"strings"
76

87
"github.com/anthropics/anthropic-sdk-go"
98
"github.com/anthropics/anthropic-sdk-go/packages/param"
@@ -51,22 +50,23 @@ func (b *MessageNewParamsWrapper) LastUserPrompt() (*string, error) {
5150
return nil, errors.New("no messages")
5251
}
5352

54-
var userMessage string
55-
for i := len(b.Messages) - 1; i >= 0; i-- {
56-
m := b.Messages[i]
57-
if m.Role != anthropic.MessageParamRoleUser {
58-
continue
59-
}
60-
if len(m.Content) == 0 {
61-
continue
62-
}
53+
// We only care if the last message was issued by a user.
54+
msg := b.Messages[len(b.Messages)-1]
55+
if msg.Role != anthropic.MessageParamRoleUser {
56+
return nil, nil
57+
}
6358

64-
for j := len(m.Content) - 1; j >= 0; j-- {
65-
if textContent := m.Content[j].GetText(); textContent != nil {
66-
userMessage = *textContent
67-
}
59+
if len(msg.Content) == 0 {
60+
return nil, nil
61+
}
6862

69-
return utils.PtrTo(strings.TrimSpace(userMessage)), nil
63+
// Walk backwards on "user"-initiated message content. Clients often inject
64+
// content ahead of the actual prompt to provide context to the model,
65+
// so the last item in the slice is most likely the user's prompt.
66+
for i := len(msg.Content) - 1; i >= 0; i-- {
67+
// Only text content is supported currently.
68+
if textContent := msg.Content[i].GetText(); textContent != nil {
69+
return textContent, nil
7070
}
7171
}
7272

anthropic_test.go

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
package aibridge_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/anthropics/anthropic-sdk-go"
7+
"github.com/coder/aibridge"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestAnthropicLastUserPrompt(t *testing.T) {
12+
t.Parallel()
13+
14+
tests := []struct {
15+
name string
16+
wrapper *aibridge.MessageNewParamsWrapper
17+
expected string
18+
expectError bool
19+
errorMsg string
20+
}{
21+
{
22+
name: "nil struct",
23+
expectError: true,
24+
errorMsg: "nil struct",
25+
},
26+
{
27+
name: "no messages",
28+
wrapper: &aibridge.MessageNewParamsWrapper{
29+
MessageNewParams: anthropic.MessageNewParams{
30+
Messages: []anthropic.MessageParam{},
31+
},
32+
},
33+
expectError: true,
34+
errorMsg: "no messages",
35+
},
36+
{
37+
name: "last message not from user",
38+
wrapper: &aibridge.MessageNewParamsWrapper{
39+
MessageNewParams: anthropic.MessageNewParams{
40+
Messages: []anthropic.MessageParam{
41+
{
42+
Role: anthropic.MessageParamRoleUser,
43+
Content: []anthropic.ContentBlockParamUnion{
44+
anthropic.NewTextBlock("user message"),
45+
},
46+
},
47+
{
48+
Role: anthropic.MessageParamRoleAssistant,
49+
Content: []anthropic.ContentBlockParamUnion{
50+
anthropic.NewTextBlock("assistant message"),
51+
},
52+
},
53+
},
54+
},
55+
},
56+
},
57+
{
58+
name: "last user message with empty content",
59+
wrapper: &aibridge.MessageNewParamsWrapper{
60+
MessageNewParams: anthropic.MessageNewParams{
61+
Messages: []anthropic.MessageParam{
62+
{
63+
Role: anthropic.MessageParamRoleUser,
64+
Content: []anthropic.ContentBlockParamUnion{},
65+
},
66+
},
67+
},
68+
},
69+
},
70+
{
71+
name: "last user message with single text content",
72+
wrapper: &aibridge.MessageNewParamsWrapper{
73+
MessageNewParams: anthropic.MessageNewParams{
74+
Messages: []anthropic.MessageParam{
75+
{
76+
Role: anthropic.MessageParamRoleUser,
77+
Content: []anthropic.ContentBlockParamUnion{
78+
anthropic.NewTextBlock("Hello, world!"),
79+
},
80+
},
81+
},
82+
},
83+
},
84+
expected: "Hello, world!",
85+
},
86+
{
87+
name: "last user message with multiple content blocks - text at end",
88+
wrapper: &aibridge.MessageNewParamsWrapper{
89+
MessageNewParams: anthropic.MessageNewParams{
90+
Messages: []anthropic.MessageParam{
91+
{
92+
Role: anthropic.MessageParamRoleUser,
93+
Content: []anthropic.ContentBlockParamUnion{
94+
anthropic.NewImageBlockBase64("image/png", "base64data"),
95+
anthropic.NewTextBlock("First text"),
96+
anthropic.NewImageBlockBase64("image/jpeg", "moredata"),
97+
anthropic.NewTextBlock("Last text"),
98+
},
99+
},
100+
},
101+
},
102+
},
103+
expected: "Last text",
104+
},
105+
{
106+
name: "last user message with only non-text content",
107+
wrapper: &aibridge.MessageNewParamsWrapper{
108+
MessageNewParams: anthropic.MessageNewParams{
109+
Messages: []anthropic.MessageParam{
110+
{
111+
Role: anthropic.MessageParamRoleUser,
112+
Content: []anthropic.ContentBlockParamUnion{
113+
anthropic.NewImageBlockBase64("image/png", "base64data"),
114+
anthropic.NewImageBlockBase64("image/jpeg", "moredata"),
115+
},
116+
},
117+
},
118+
},
119+
},
120+
},
121+
{
122+
name: "multiple messages with last being user",
123+
wrapper: &aibridge.MessageNewParamsWrapper{
124+
MessageNewParams: anthropic.MessageNewParams{
125+
Messages: []anthropic.MessageParam{
126+
{
127+
Role: anthropic.MessageParamRoleUser,
128+
Content: []anthropic.ContentBlockParamUnion{
129+
anthropic.NewTextBlock("First user message"),
130+
},
131+
},
132+
{
133+
Role: anthropic.MessageParamRoleAssistant,
134+
Content: []anthropic.ContentBlockParamUnion{
135+
anthropic.NewTextBlock("Assistant response"),
136+
},
137+
},
138+
{
139+
Role: anthropic.MessageParamRoleUser,
140+
Content: []anthropic.ContentBlockParamUnion{
141+
anthropic.NewTextBlock("Second user message"),
142+
},
143+
},
144+
},
145+
},
146+
},
147+
expected: "Second user message",
148+
},
149+
}
150+
151+
for _, tt := range tests {
152+
t.Run(tt.name, func(t *testing.T) {
153+
result, err := tt.wrapper.LastUserPrompt()
154+
155+
if tt.expectError {
156+
require.Error(t, err)
157+
require.Contains(t, err.Error(), tt.errorMsg)
158+
require.Nil(t, result)
159+
} else {
160+
require.NoError(t, err)
161+
// Check pointer equality - both nil or both non-nil
162+
if tt.expected == "" {
163+
require.Nil(t, result)
164+
} else {
165+
require.NotNil(t, result)
166+
// The result should point to the same string from the content block
167+
require.Equal(t, tt.expected, *result)
168+
}
169+
}
170+
})
171+
}
172+
}

bridge_integration_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ func TestSimple(t *testing.T) {
426426

427427
// Then: I expect the prompt to have been tracked.
428428
require.NotEmpty(t, recorderClient.userPrompts, "no prompts tracked")
429-
assert.Equal(t, "how many angels can dance on the head of a pin", recorderClient.userPrompts[0].Prompt)
429+
assert.Contains(t, recorderClient.userPrompts[0].Prompt, "how many angels can dance on the head of a pin")
430430

431431
// Validate that responses have their IDs overridden with a interception ID rather than the original ID from the upstream provider.
432432
// The reason for this is that Bridge may make multiple upstream requests (i.e. to invoke injected tools), and clients will not be expecting

openai.go

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package aibridge
33
import (
44
"encoding/json"
55
"errors"
6-
"strings"
76

87
"github.com/anthropics/anthropic-sdk-go/shared"
98
"github.com/anthropics/anthropic-sdk-go/shared/constant"
@@ -56,22 +55,27 @@ func (c *ChatCompletionNewParamsWrapper) LastUserPrompt() (*string, error) {
5655
return nil, errors.New("no messages")
5756
}
5857

59-
var msg *openai.ChatCompletionUserMessageParam
60-
for i := len(c.Messages) - 1; i >= 0; i-- {
61-
m := c.Messages[i]
62-
if m.OfUser != nil {
63-
msg = m.OfUser
64-
break
65-
}
58+
// We only care if the last message was issued by a user.
59+
msg := c.Messages[len(c.Messages)-1]
60+
if msg.OfUser == nil {
61+
return nil, nil
6662
}
6763

68-
if msg == nil {
69-
return nil, nil
64+
if msg.OfUser.Content.OfString.String() != "" {
65+
return utils.PtrTo(msg.OfUser.Content.OfString.String()), nil
66+
}
67+
68+
// Walk backwards on "user"-initiated message content. Clients often inject
69+
// content ahead of the actual prompt to provide context to the model,
70+
// so the last item in the slice is most likely the user's prompt.
71+
for i := len(msg.OfUser.Content.OfArrayOfContentParts) - 1; i >= 0; i-- {
72+
// Only text content is supported currently.
73+
if textContent := msg.OfUser.Content.OfArrayOfContentParts[i].OfText; textContent != nil {
74+
return &textContent.Text, nil
75+
}
7076
}
7177

72-
return utils.PtrTo(strings.TrimSpace(
73-
msg.Content.OfString.String(),
74-
)), nil
78+
return nil, nil
7579
}
7680

7781
func sumUsage(ref, in openai.CompletionUsage) openai.CompletionUsage {

0 commit comments

Comments
 (0)