Skip to content

Commit c5af652

Browse files
committed
feat: initial implementation
0 parents  commit c5af652

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+7838
-0
lines changed

Makefile

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Use a single bash shell for each job, and immediately exit on failure
2+
.SHELL := /usr/bin/env bash
3+
.SHELLFLAGS := -ceu
4+
.ONESHELL:
5+
6+
# This doesn't work on directories.
7+
# See https://stackoverflow.com/questions/25752543/make-delete-on-error-for-directory-targets
8+
.DELETE_ON_ERROR:
9+
10+
# Don't print the commands in the file unless you specify VERBOSE. This is
11+
# essentially the same as putting "@" at the start of each line.
12+
ifndef VERBOSE
13+
.SILENT:
14+
endif
15+
16+
test:
17+
go test -count=1 ./...
18+
19+
test-race:
20+
CGO_ENABLED=1 go test -count=1 -race ./...
21+
22+
coverage:
23+
go test -coverprofile=coverage.out ./...
24+
go tool cover -func=coverage.out | tail -n 1
25+
26+
coverage-html:
27+
@go test -coverprofile=coverage.out ./...
28+
@go tool cover -html=coverage.out
29+
30+
clean:
31+
rm -f coverage.out
32+
33+
fmt: fmt/go
34+
.PHONY: fmt
35+
36+
fmt/go:
37+
ifdef FILE
38+
# Format single file
39+
if [[ -f "$(FILE)" ]] && [[ "$(FILE)" == *.go ]] && ! grep -q "DO NOT EDIT" "$(FILE)"; then \
40+
go run mvdan.cc/[email protected] -w -l "$(FILE)"; \
41+
fi
42+
else
43+
go mod tidy
44+
find . -type f -name '*.go' -print0 | \
45+
xargs -0 grep -E --null -L '^// Code generated .* DO NOT EDIT\.$$' | \
46+
xargs -0 go run mvdan.cc/[email protected] -w -l
47+
endif
48+
.PHONY: fmt/go
49+
50+
mocks: mcp/api.go
51+
go generate ./mcpmock/

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
aibridge

anthropic.go

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
package aibridge
2+
3+
import (
4+
"encoding/json"
5+
"errors"
6+
"strings"
7+
8+
"github.com/anthropics/anthropic-sdk-go"
9+
"github.com/anthropics/anthropic-sdk-go/packages/param"
10+
"github.com/coder/aibridge/utils"
11+
)
12+
13+
// MessageNewParamsWrapper exists because the "stream" param is not included in anthropic.MessageNewParams.
14+
type MessageNewParamsWrapper struct {
15+
anthropic.MessageNewParams `json:""`
16+
Stream bool `json:"stream,omitempty"`
17+
}
18+
19+
func (b MessageNewParamsWrapper) MarshalJSON() ([]byte, error) {
20+
type shadow MessageNewParamsWrapper
21+
return param.MarshalWithExtras(b, (*shadow)(&b), map[string]any{
22+
"stream": b.Stream,
23+
})
24+
}
25+
26+
func (b *MessageNewParamsWrapper) UnmarshalJSON(raw []byte) error {
27+
convertedRaw, err := convertStringContentToArray(raw)
28+
if err != nil {
29+
return err
30+
}
31+
32+
err = b.MessageNewParams.UnmarshalJSON(convertedRaw)
33+
if err != nil {
34+
return err
35+
}
36+
37+
b.Stream = utils.ExtractJSONField[bool](raw, "stream")
38+
return nil
39+
}
40+
41+
func (b *MessageNewParamsWrapper) UseStreaming() bool {
42+
return b.Stream
43+
}
44+
45+
func (b *MessageNewParamsWrapper) LastUserPrompt() (*string, error) {
46+
if b == nil {
47+
return nil, errors.New("nil struct")
48+
}
49+
50+
if len(b.Messages) == 0 {
51+
return nil, errors.New("no messages")
52+
}
53+
54+
var userMessage string
55+
for i := len(b.Messages) - 1; i >= 0; i-- {
56+
m := b.Messages[i]
57+
if m.Role != anthropic.MessageParamRoleUser {
58+
continue
59+
}
60+
if len(m.Content) == 0 {
61+
continue
62+
}
63+
64+
for j := len(m.Content) - 1; j >= 0; j-- {
65+
if textContent := m.Content[j].GetText(); textContent != nil {
66+
userMessage = *textContent
67+
}
68+
69+
return utils.PtrTo(strings.TrimSpace(userMessage)), nil
70+
}
71+
}
72+
73+
return nil, nil
74+
}
75+
76+
// convertStringContentToArray converts string content to array format for Anthropic messages.
77+
// https://docs.anthropic.com/en/api/messages#body-messages
78+
//
79+
// Each input message content may be either a single string or an array of content blocks, where each block has a
80+
// specific type. Using a string for content is shorthand for an array of one content block of type "text".
81+
func convertStringContentToArray(raw []byte) ([]byte, error) {
82+
var modifiedJSON map[string]any
83+
if err := json.Unmarshal(raw, &modifiedJSON); err != nil {
84+
return raw, err
85+
}
86+
87+
// Check if messages exist and need content conversion
88+
if _, hasMessages := modifiedJSON["messages"]; hasMessages {
89+
convertStringContentRecursive(modifiedJSON)
90+
91+
// Marshal back to JSON
92+
return json.Marshal(modifiedJSON)
93+
}
94+
95+
return raw, nil
96+
}
97+
98+
// convertStringContentRecursive recursively scans JSON data and converts string "content" fields
99+
// to proper text block arrays where needed for Anthropic SDK compatibility
100+
func convertStringContentRecursive(data any) {
101+
switch v := data.(type) {
102+
case map[string]any:
103+
// Check if this object has a "content" field with string value
104+
if content, hasContent := v["content"]; hasContent {
105+
if contentStr, isString := content.(string); isString {
106+
// Check if this needs conversion based on context
107+
if shouldConvertContentField(v) {
108+
v["content"] = []map[string]any{
109+
{
110+
"type": "text",
111+
"text": contentStr,
112+
},
113+
}
114+
}
115+
}
116+
}
117+
118+
// Recursively process all values in the map
119+
for _, value := range v {
120+
convertStringContentRecursive(value)
121+
}
122+
123+
case []any:
124+
// Recursively process all items in the array
125+
for _, item := range v {
126+
convertStringContentRecursive(item)
127+
}
128+
}
129+
}
130+
131+
// shouldConvertContentField determines if a "content" string field should be converted to text block array
132+
func shouldConvertContentField(obj map[string]any) bool {
133+
// Check if this is a message-level content (has "role" field)
134+
if _, hasRole := obj["role"]; hasRole {
135+
return true
136+
}
137+
138+
// Check if this is a tool_result block (but not mcp_tool_result which supports strings)
139+
if objType, hasType := obj["type"].(string); hasType {
140+
switch objType {
141+
case "tool_result":
142+
return true // Regular tool_result needs array format
143+
case "mcp_tool_result":
144+
return false // MCP tool_result supports strings
145+
}
146+
}
147+
148+
return false
149+
}
150+
151+
// accumulateUsage accumulates usage statistics from source into dest.
152+
// It handles both [anthropic.Usage] and [anthropic.MessageDeltaUsage] types through [any].
153+
// The function uses reflection to handle the differences between the types:
154+
// - [anthropic.Usage] has CacheCreation field with ephemeral tokens
155+
// - [anthropic.MessageDeltaUsage] doesn't have CacheCreation field
156+
func accumulateUsage(dest, src any) {
157+
switch d := dest.(type) {
158+
case *anthropic.Usage:
159+
if d == nil {
160+
return
161+
}
162+
switch s := src.(type) {
163+
case anthropic.Usage:
164+
// Usage -> Usage
165+
d.CacheCreation.Ephemeral1hInputTokens += s.CacheCreation.Ephemeral1hInputTokens
166+
d.CacheCreation.Ephemeral5mInputTokens += s.CacheCreation.Ephemeral5mInputTokens
167+
d.CacheCreationInputTokens += s.CacheCreationInputTokens
168+
d.CacheReadInputTokens += s.CacheReadInputTokens
169+
d.InputTokens += s.InputTokens
170+
d.OutputTokens += s.OutputTokens
171+
d.ServerToolUse.WebSearchRequests += s.ServerToolUse.WebSearchRequests
172+
case anthropic.MessageDeltaUsage:
173+
// MessageDeltaUsage -> Usage
174+
d.CacheCreationInputTokens += s.CacheCreationInputTokens
175+
d.CacheReadInputTokens += s.CacheReadInputTokens
176+
d.InputTokens += s.InputTokens
177+
d.OutputTokens += s.OutputTokens
178+
d.ServerToolUse.WebSearchRequests += s.ServerToolUse.WebSearchRequests
179+
}
180+
case *anthropic.MessageDeltaUsage:
181+
if d == nil {
182+
return
183+
}
184+
switch s := src.(type) {
185+
case anthropic.Usage:
186+
// Usage -> MessageDeltaUsage (only common fields)
187+
d.CacheCreationInputTokens += s.CacheCreationInputTokens
188+
d.CacheReadInputTokens += s.CacheReadInputTokens
189+
d.InputTokens += s.InputTokens
190+
d.OutputTokens += s.OutputTokens
191+
d.ServerToolUse.WebSearchRequests += s.ServerToolUse.WebSearchRequests
192+
case anthropic.MessageDeltaUsage:
193+
// MessageDeltaUsage -> MessageDeltaUsage
194+
d.CacheCreationInputTokens += s.CacheCreationInputTokens
195+
d.CacheReadInputTokens += s.CacheReadInputTokens
196+
d.InputTokens += s.InputTokens
197+
d.OutputTokens += s.OutputTokens
198+
d.ServerToolUse.WebSearchRequests += s.ServerToolUse.WebSearchRequests
199+
}
200+
}
201+
}

0 commit comments

Comments
 (0)