Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"net"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
Expand All @@ -25,6 +27,7 @@ import (
"github.com/coder/aibridge/config"
aibcontext "github.com/coder/aibridge/context"
"github.com/coder/aibridge/fixtures"
"github.com/coder/aibridge/intercept/apidump"
"github.com/coder/aibridge/internal/testutil"
"github.com/coder/aibridge/mcp"
"github.com/coder/aibridge/provider"
Expand Down Expand Up @@ -1851,9 +1854,140 @@ func openaiCfg(url, key string) config.OpenAI {
}
}

func openaiCfgWithAPIDump(url, key, dumpDir string) config.OpenAI {
return config.OpenAI{
BaseURL: url,
Key: key,
APIDumpDir: dumpDir,
}
}

func anthropicCfg(url, key string) config.Anthropic {
return config.Anthropic{
BaseURL: url,
Key: key,
}
}

func anthropicCfgWithAPIDump(url, key, dumpDir string) config.Anthropic {
return config.Anthropic{
BaseURL: url,
Key: key,
APIDumpDir: dumpDir,
}
}

func TestAPIDump(t *testing.T) {
t.Parallel()

cases := []struct {
name string
fixture []byte
providerName string
providersFunc func(addr, dumpDir string) []aibridge.Provider
createRequestFunc createRequestFunc
}{
{
name: config.ProviderAnthropic,
fixture: fixtures.AntSimple,
providerName: config.ProviderAnthropic,
providersFunc: func(addr, dumpDir string) []aibridge.Provider {
return []aibridge.Provider{provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil)}
},
createRequestFunc: createAnthropicMessagesReq,
},
{
name: config.ProviderOpenAI,
fixture: fixtures.OaiChatSimple,
providerName: config.ProviderOpenAI,
providersFunc: func(addr, dumpDir string) []aibridge.Provider {
return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))}
},
createRequestFunc: createOpenAIChatCompletionsReq,
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)

ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
t.Cleanup(cancel)

arc := txtar.Parse(tc.fixture)
files := filesMap(arc)
require.Contains(t, files, fixtureRequest)
require.Contains(t, files, fixtureNonStreamingResponse)

reqBody := files[fixtureRequest]

// Setup mock upstream server.
srv := newMockServer(ctx, t, files, nil)
t.Cleanup(srv.Close)

// Create temp dir for API dumps.
dumpDir := t.TempDir()

recorderClient := &testutil.MockRecorder{}
b, err := aibridge.NewRequestBridge(t.Context(), tc.providersFunc(srv.URL, dumpDir), recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer)
require.NoError(t, err)

mockSrv := httptest.NewUnstartedServer(b)
t.Cleanup(mockSrv.Close)
mockSrv.Config.BaseContext = func(_ net.Listener) context.Context {
return aibcontext.AsActor(ctx, userID, nil)
}
mockSrv.Start()

req := tc.createRequestFunc(t, mockSrv.URL, reqBody)
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
defer resp.Body.Close()
_, _ = io.ReadAll(resp.Body)

// Verify dump files were created.
interceptions := recorderClient.RecordedInterceptions()
require.Len(t, interceptions, 1)
interceptionID := interceptions[0].ID

// Find dump files for this interception by walking the dump directory.
var reqDumpFile, respDumpFile string
err = filepath.Walk(dumpDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if info.IsDir() {
return nil
}
// Files are named: {timestamp}-{interceptionID}.{req|resp}.txt
if strings.Contains(path, interceptionID) {
if strings.HasSuffix(path, apidump.SuffixRequest) {
reqDumpFile = path
} else if strings.HasSuffix(path, apidump.SuffixResponse) {
respDumpFile = path
}
}
return nil
})
require.NoError(t, err)
require.NotEmpty(t, reqDumpFile, "request dump file should exist")
require.NotEmpty(t, respDumpFile, "response dump file should exist")

// Verify request dump contains expected HTTP request format.
reqDumpData, err := os.ReadFile(reqDumpFile)
require.NoError(t, err)
require.Contains(t, string(reqDumpData), "POST ")
require.Contains(t, string(reqDumpData), "Host:")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe add check like require.Contains(t, string(reqDumpData), string(reqBody))

and similar for response


// Verify response dump contains expected HTTP response format.
respDumpData, err := os.ReadFile(respDumpFile)
require.NoError(t, err)
require.Contains(t, string(respDumpData), "200 OK")

recorderClient.VerifyAllInterceptionsEnded(t)
})
}
}
2 changes: 2 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func DefaultCircuitBreaker() CircuitBreaker {
type Anthropic struct {
BaseURL string
Key string
APIDumpDir string
CircuitBreaker *CircuitBreaker
}

Expand All @@ -53,5 +54,6 @@ type AWSBedrock struct {
type OpenAI struct {
BaseURL string
Key string
APIDumpDir string
CircuitBreaker *CircuitBreaker
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.24.6
// Misc libs.
require (
cdr.dev/slog/v3 v3.0.0-rc1
github.com/coder/quartz v0.3.0
github.com/google/uuid v1.6.0
github.com/hashicorp/go-multierror v1.1.1
github.com/mark3labs/mcp-go v0.38.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/charmbracelet/lipgloss v0.7.1 h1:17WMwi7N1b1rVWOjMT+rCh7sQkvDU75B2hbZpc5Kc1E=
github.com/charmbracelet/lipgloss v0.7.1/go.mod h1:yG0k3giv8Qj8edTCbbg6AlQ5e8KNWpFujkNawKNhE2c=
github.com/coder/quartz v0.3.0 h1:bUoSEJ77NBfKtUqv6CPSC0AS8dsjqAqqAv7bN02m1mg=
github.com/coder/quartz v0.3.0/go.mod h1:BgE7DOj/8NfvRgvKw0jPLDQH/2Lya2kxcTaNJ8X0rZk=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM=
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
Expand Down
161 changes: 161 additions & 0 deletions intercept/apidump/apidump.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package apidump

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"

"github.com/coder/quartz"
"github.com/google/uuid"
)

const (
// SuffixRequest is the file suffix for request dump files.
SuffixRequest = ".req.txt"
// SuffixResponse is the file suffix for response dump files.
SuffixResponse = ".resp.txt"
)

// MiddlewareNext is the function to call the next middleware or the actual request.
type MiddlewareNext = func(*http.Request) (*http.Response, error)

// Middleware is an HTTP middleware function compatible with SDK WithMiddleware options.
type Middleware = func(*http.Request, MiddlewareNext) (*http.Response, error)

// NewMiddleware returns a middleware function that dumps requests and responses to files.
// Files are written to the path returned by DumpPath.
// If baseDir is empty, returns nil (no middleware).
func NewMiddleware(baseDir, provider, model string, interceptionID uuid.UUID, clk quartz.Clock) Middleware {
if baseDir == "" {
return nil
}

d := &dumper{
baseDir: baseDir,
provider: provider,
model: model,
interceptionID: interceptionID,
clk: clk,
}

return func(req *http.Request, next MiddlewareNext) (*http.Response, error) {
if err := d.dumpRequest(req); err != nil {
fmt.Fprintf(os.Stderr, "apidump: failed to dump request: %v\n", err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe logger?

}

resp, err := next(req)
if err != nil {
return resp, err
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about it but maybe it be useful to have "err" file with error?

}

if err := d.dumpResponse(resp); err != nil {
fmt.Fprintf(os.Stderr, "apidump: failed to dump response: %v\n", err)
}

return resp, nil
}
}

type dumper struct {
baseDir string
provider string
model string
interceptionID uuid.UUID
clk quartz.Clock
}

func (d *dumper) dumpRequest(req *http.Request) error {
dumpPath := d.path(SuffixRequest)
if err := os.MkdirAll(filepath.Dir(dumpPath), 0o755); err != nil {
return fmt.Errorf("create dump dir: %w", err)
}

// Read and restore body
var bodyBytes []byte
if req.Body != nil {
var err error
bodyBytes, err = io.ReadAll(req.Body)
if err != nil {
return fmt.Errorf("read request body: %w", err)
}
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
}

// Build raw HTTP request format
var buf bytes.Buffer
fmt.Fprintf(&buf, "%s %s %s\r\n", req.Method, req.URL.RequestURI(), req.Proto)
fmt.Fprintf(&buf, "Host: %s\r\n", req.Host)
for key, values := range req.Header {
_, sensitive := sensitiveRequestHeaders[key]
for _, value := range values {
if sensitive {
value = redactHeaderValue(value)
}
fmt.Fprintf(&buf, "%s: %s\r\n", key, value)
}
}
fmt.Fprintf(&buf, "\r\n")
buf.Write(prettyPrintJSON(bodyBytes))

return os.WriteFile(dumpPath, buf.Bytes(), 0o644)
}

func (d *dumper) dumpResponse(resp *http.Response) error {
dumpPath := d.path(SuffixResponse)

// Build raw HTTP response headers
var headerBuf bytes.Buffer
fmt.Fprintf(&headerBuf, "%s %s\r\n", resp.Proto, resp.Status)
for key, values := range resp.Header {
_, sensitive := sensitiveResponseHeaders[key]
for _, value := range values {
if sensitive {
value = redactHeaderValue(value)
}
fmt.Fprintf(&headerBuf, "%s: %s\r\n", key, value)
}
}
fmt.Fprintf(&headerBuf, "\r\n")

// Wrap the response body to capture it as it streams
if resp.Body != nil {
resp.Body = &streamingBodyDumper{
body: resp.Body,
dumpPath: dumpPath,
headerData: headerBuf.Bytes(),
}
} else {
// No body, just write headers
return os.WriteFile(dumpPath, headerBuf.Bytes(), 0o644)
}

return nil
}

// path returns the path to a request/response dump file for a given interception.
// suffix should be SuffixRequest or SuffixResponse.
func (d *dumper) path(suffix string) string {
safeModel := strings.ReplaceAll(d.model, "/", "-")
return filepath.Join(d.baseDir, d.provider, safeModel, fmt.Sprintf("%d-%s%s", d.clk.Now().UTC().UnixMilli(), d.interceptionID, suffix))
}

// prettyPrintJSON returns indented JSON if body is valid JSON, otherwise returns body as-is.
func prettyPrintJSON(body []byte) []byte {
if len(body) == 0 {
return body
}
var parsed any
if err := json.Unmarshal(body, &parsed); err != nil {
return body
}
pretty, err := json.MarshalIndent(parsed, "", " ")
if err != nil {
return body
}
return pretty
}
Loading