diff --git a/apidump_integration_test.go b/apidump_integration_test.go new file mode 100644 index 0000000..07ce52d --- /dev/null +++ b/apidump_integration_test.go @@ -0,0 +1,177 @@ +package aibridge_test + +import ( + "bufio" + "bytes" + "context" + "io" + "net" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/aibridge" + "github.com/coder/aibridge/config" + aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/fixtures" + "github.com/coder/aibridge/intercept/apidump" + "github.com/coder/aibridge/internal/testutil" + "github.com/coder/aibridge/mcp" + "github.com/coder/aibridge/provider" + "github.com/stretchr/testify/require" + "golang.org/x/tools/txtar" +) + +func openaiCfgWithAPIDump(url, key, dumpDir string) config.OpenAI { + return config.OpenAI{ + BaseURL: url, + Key: key, + APIDumpDir: dumpDir, + } +} + +func anthropicCfgWithAPIDump(url, key, dumpDir string) config.Anthropic { + return config.Anthropic{ + BaseURL: url, + Key: key, + APIDumpDir: dumpDir, + } +} + +func TestAPIDump(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + fixture []byte + providerName string + providersFunc func(addr, dumpDir string) []aibridge.Provider + createRequestFunc createRequestFunc + }{ + { + name: config.ProviderAnthropic, + fixture: fixtures.AntSimple, + providerName: config.ProviderAnthropic, + providersFunc: func(addr, dumpDir string) []aibridge.Provider { + return []aibridge.Provider{provider.NewAnthropic(anthropicCfgWithAPIDump(addr, apiKey, dumpDir), nil)} + }, + createRequestFunc: createAnthropicMessagesReq, + }, + { + name: config.ProviderOpenAI, + fixture: fixtures.OaiChatSimple, + providerName: config.ProviderOpenAI, + providersFunc: func(addr, dumpDir string) []aibridge.Provider { + return []aibridge.Provider{provider.NewOpenAI(openaiCfgWithAPIDump(addr, apiKey, dumpDir))} + }, + createRequestFunc: createOpenAIChatCompletionsReq, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + + ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) + t.Cleanup(cancel) + + arc := txtar.Parse(tc.fixture) + files := filesMap(arc) + require.Contains(t, files, fixtureRequest) + require.Contains(t, files, fixtureNonStreamingResponse) + + reqBody := files[fixtureRequest] + + // Setup mock upstream server. + srv := newMockServer(ctx, t, files, nil) + t.Cleanup(srv.Close) + + // Create temp dir for API dumps. + dumpDir := t.TempDir() + + recorderClient := &testutil.MockRecorder{} + b, err := aibridge.NewRequestBridge(t.Context(), tc.providersFunc(srv.URL, dumpDir), recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) + require.NoError(t, err) + + mockSrv := httptest.NewUnstartedServer(b) + t.Cleanup(mockSrv.Close) + mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { + return aibcontext.AsActor(ctx, userID, nil) + } + mockSrv.Start() + + req := tc.createRequestFunc(t, mockSrv.URL, reqBody) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + defer resp.Body.Close() + _, _ = io.ReadAll(resp.Body) + + // Verify dump files were created. + interceptions := recorderClient.RecordedInterceptions() + require.Len(t, interceptions, 1) + interceptionID := interceptions[0].ID + + // Find dump files for this interception by walking the dump directory. + var reqDumpFile, respDumpFile string + err = filepath.Walk(dumpDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + // Files are named: {timestamp}-{interceptionID}.{req|resp}.txt + if strings.Contains(path, interceptionID) { + if strings.HasSuffix(path, apidump.SuffixRequest) { + reqDumpFile = path + } else if strings.HasSuffix(path, apidump.SuffixResponse) { + respDumpFile = path + } + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, reqDumpFile, "request dump file should exist") + require.NotEmpty(t, respDumpFile, "response dump file should exist") + + // Verify request dump contains expected HTTP request format. + reqDumpData, err := os.ReadFile(reqDumpFile) + require.NoError(t, err) + + // Parse the dumped HTTP request. + dumpReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(reqDumpData))) + require.NoError(t, err) + dumpBody, err := io.ReadAll(dumpReq.Body) + require.NoError(t, err) + + // Compare requests semantically (key order may differ). + require.JSONEq(t, string(dumpBody), string(reqBody), "request body JSON should match semantically") + + // Verify response dump contains expected HTTP response format. + respDumpData, err := os.ReadFile(respDumpFile) + require.NoError(t, err) + + // Parse the dumped HTTP response. + dumpResp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respDumpData)), nil) + require.NoError(t, err) + require.Equal(t, http.StatusOK, dumpResp.StatusCode) + dumpRespBody, err := io.ReadAll(dumpResp.Body) + require.NoError(t, err) + + // Compare responses semantically (key order may differ). + expectedRespBody := files[fixtureNonStreamingResponse] + require.JSONEq(t, string(expectedRespBody), string(dumpRespBody), "response body JSON should match semantically") + + recorderClient.VerifyAllInterceptionsEnded(t) + }) + } +} diff --git a/config/config.go b/config/config.go index 5b95398..74209c6 100644 --- a/config/config.go +++ b/config/config.go @@ -38,6 +38,7 @@ func DefaultCircuitBreaker() CircuitBreaker { type Anthropic struct { BaseURL string Key string + APIDumpDir string CircuitBreaker *CircuitBreaker } @@ -53,5 +54,6 @@ type AWSBedrock struct { type OpenAI struct { BaseURL string Key string + APIDumpDir string CircuitBreaker *CircuitBreaker } diff --git a/go.mod b/go.mod index 829d837..ec46988 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.24.6 // Misc libs. require ( cdr.dev/slog/v3 v3.0.0-rc1 + github.com/coder/quartz v0.3.0 github.com/google/uuid v1.6.0 github.com/hashicorp/go-multierror v1.1.1 github.com/mark3labs/mcp-go v0.38.0 @@ -12,6 +13,7 @@ require ( github.com/sony/gobreaker/v2 v2.3.0 github.com/stretchr/testify v1.11.1 github.com/tidwall/gjson v1.18.0 + github.com/tidwall/pretty v1.2.1 github.com/tidwall/sjson v1.2.5 go.uber.org/goleak v1.3.0 go.uber.org/mock v0.6.0 @@ -73,7 +75,6 @@ require ( github.com/rivo/uniseg v0.4.4 // indirect github.com/spf13/cast v1.7.1 // indirect github.com/tidwall/match v1.2.0 // indirect - github.com/tidwall/pretty v1.2.1 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect diff --git a/go.sum b/go.sum index 64d8e16..57b38d7 100644 --- a/go.sum +++ b/go.sum @@ -49,6 +49,8 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/charmbracelet/lipgloss v0.7.1 h1:17WMwi7N1b1rVWOjMT+rCh7sQkvDU75B2hbZpc5Kc1E= github.com/charmbracelet/lipgloss v0.7.1/go.mod h1:yG0k3giv8Qj8edTCbbg6AlQ5e8KNWpFujkNawKNhE2c= +github.com/coder/quartz v0.3.0 h1:bUoSEJ77NBfKtUqv6CPSC0AS8dsjqAqqAv7bN02m1mg= +github.com/coder/quartz v0.3.0/go.mod h1:BgE7DOj/8NfvRgvKw0jPLDQH/2Lya2kxcTaNJ8X0rZk= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= diff --git a/intercept/apidump/apidump.go b/intercept/apidump/apidump.go new file mode 100644 index 0000000..a0f3bce --- /dev/null +++ b/intercept/apidump/apidump.go @@ -0,0 +1,195 @@ +package apidump + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "slices" + "strings" + + "cdr.dev/slog/v3" + + "github.com/coder/quartz" + "github.com/google/uuid" + "github.com/tidwall/pretty" +) + +const ( + // SuffixRequest is the file suffix for request dump files. + SuffixRequest = ".req.txt" + // SuffixResponse is the file suffix for response dump files. + SuffixResponse = ".resp.txt" +) + +// MiddlewareNext is the function to call the next middleware or the actual request. +type MiddlewareNext = func(*http.Request) (*http.Response, error) + +// Middleware is an HTTP middleware function compatible with SDK WithMiddleware options. +type Middleware = func(*http.Request, MiddlewareNext) (*http.Response, error) + +// NewMiddleware returns a middleware function that dumps requests and responses to files. +// Files are written to the path returned by DumpPath. +// If baseDir is empty, returns nil (no middleware). +func NewMiddleware(baseDir, provider, model string, interceptionID uuid.UUID, logger slog.Logger, clk quartz.Clock) Middleware { + if baseDir == "" { + return nil + } + + d := &dumper{ + baseDir: baseDir, + provider: provider, + model: model, + interceptionID: interceptionID, + clk: clk, + logger: logger, + } + + return func(req *http.Request, next MiddlewareNext) (*http.Response, error) { + if err := d.dumpRequest(req); err != nil { + logger.Named("apidump").Warn(context.Background(), "failed to dump request", slog.Error(err)) + } + + resp, err := next(req) + if err != nil { + return resp, err + } + + if err := d.dumpResponse(resp); err != nil { + logger.Named("apidump").Warn(context.Background(), "failed to dump response", slog.Error(err)) + } + + return resp, nil + } +} + +type dumper struct { + baseDir string + provider string + model string + interceptionID uuid.UUID + clk quartz.Clock + logger slog.Logger +} + +func (d *dumper) dumpRequest(req *http.Request) error { + dumpPath := d.path(SuffixRequest) + if err := os.MkdirAll(filepath.Dir(dumpPath), 0o755); err != nil { + return fmt.Errorf("create dump dir: %w", err) + } + + // Read and restore body + var bodyBytes []byte + if req.Body != nil { + var err error + bodyBytes, err = io.ReadAll(req.Body) + if err != nil { + return fmt.Errorf("read request body: %w", err) + } + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + } + + prettyBody := prettyPrintJSON(bodyBytes) + + // Build raw HTTP request format + var buf bytes.Buffer + fmt.Fprintf(&buf, "%s %s %s\r\n", req.Method, req.URL.RequestURI(), req.Proto) + fmt.Fprintf(&buf, "Host: %s\r\n", req.Host) + fmt.Fprintf(&buf, "Content-Length: %d\r\n", len(prettyBody)) + + // Sort header keys for deterministic output. + headerKeys := make([]string, 0, len(req.Header)) + for key := range req.Header { + headerKeys = append(headerKeys, key) + } + slices.Sort(headerKeys) + + for _, key := range headerKeys { + // Skip Content-Length since we write it explicitly above with the pretty-printed body length. + if key == "Content-Length" { + continue + } + _, sensitive := sensitiveRequestHeaders[key] + for _, value := range req.Header[key] { + if sensitive { + value = redactHeaderValue(value) + } + fmt.Fprintf(&buf, "%s: %s\r\n", key, value) + } + } + fmt.Fprintf(&buf, "\r\n") + buf.Write(prettyBody) + + return os.WriteFile(dumpPath, buf.Bytes(), 0o644) +} + +func (d *dumper) dumpResponse(resp *http.Response) error { + dumpPath := d.path(SuffixResponse) + + // Build raw HTTP response headers + var headerBuf bytes.Buffer + fmt.Fprintf(&headerBuf, "%s %s\r\n", resp.Proto, resp.Status) + + // Sort header keys for deterministic output. + headerKeys := make([]string, 0, len(resp.Header)) + for key := range resp.Header { + headerKeys = append(headerKeys, key) + } + slices.Sort(headerKeys) + + for _, key := range headerKeys { + _, sensitive := sensitiveResponseHeaders[key] + for _, value := range resp.Header[key] { + if sensitive { + value = redactHeaderValue(value) + } + fmt.Fprintf(&headerBuf, "%s: %s\r\n", key, value) + } + } + fmt.Fprintf(&headerBuf, "\r\n") + + // Wrap the response body to capture it as it streams + if resp.Body != nil { + resp.Body = &streamingBodyDumper{ + body: resp.Body, + dumpPath: dumpPath, + headerData: headerBuf.Bytes(), + logger: func(err error) { + d.logger.Named("apidump").Warn(context.Background(), "failed to initialize response dump", slog.Error(err)) + }, + } + } else { + // No body, just write headers + return os.WriteFile(dumpPath, headerBuf.Bytes(), 0o644) + } + + return nil +} + +// path returns the path to a request/response dump file for a given interception. +// suffix should be SuffixRequest or SuffixResponse. +func (d *dumper) path(suffix string) string { + safeModel := strings.ReplaceAll(d.model, "/", "-") + return filepath.Join(d.baseDir, d.provider, safeModel, fmt.Sprintf("%d-%s%s", d.clk.Now().UTC().UnixMilli(), d.interceptionID, suffix)) +} + +// prettyPrintJSON returns indented JSON if body is valid JSON, otherwise returns body as-is. +// Unlike json.MarshalIndent, this preserves the original key order from the input, +// which makes the dumps easier to read and compare with the original requests. +func prettyPrintJSON(body []byte) []byte { + if len(body) == 0 { + return body + } + result := pretty.Pretty(body) + // pretty.Pretty returns a truncated/modified result for invalid JSON, + // so check if the result is valid JSON; if not, return the original. + if !json.Valid(result) { + return body + } + // Trim trailing newline added by pretty.Pretty. + return bytes.TrimSuffix(result, []byte("\n")) +} diff --git a/intercept/apidump/apidump_test.go b/intercept/apidump/apidump_test.go new file mode 100644 index 0000000..a7b4c5b --- /dev/null +++ b/intercept/apidump/apidump_test.go @@ -0,0 +1,313 @@ +package apidump + +import ( + "bytes" + "io" + "net/http" + "os" + "path/filepath" + "testing" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/quartz" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +// findDumpFile finds a dump file matching the pattern in the given directory. +func findDumpFile(t *testing.T, dir, suffix string) string { + t.Helper() + pattern := filepath.Join(dir, "*"+suffix) + matches, err := filepath.Glob(pattern) + require.NoError(t, err) + require.Len(t, matches, 1, "expected exactly one %s file in %s", suffix, dir) + return matches[0] +} + +func TestMiddleware_RedactsSensitiveRequestHeaders(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + clk := quartz.NewMock(t) + interceptionID := uuid.New() + + middleware := NewMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) + require.NotNil(t, middleware) + + req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{"test": true}`))) + require.NoError(t, err) + + // Add sensitive headers that should be redacted + req.Header.Set("Authorization", "Bearer sk-secret-key-12345") + req.Header.Set("X-Api-Key", "secret-api-key-value") + req.Header.Set("Cookie", "session=abc123") + + // Add non-sensitive headers that should be kept as-is + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "test-client") + + // Call middleware with a mock next function + _, err = middleware(req, func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Proto: "HTTP/1.1", + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(bytes.NewReader([]byte(`{"ok": true}`))), + }, nil + }) + require.NoError(t, err) + + // Read the request dump file + modelDir := filepath.Join(tmpDir, "openai", "gpt-4") + reqDumpPath := findDumpFile(t, modelDir, SuffixRequest) + reqContent, err := os.ReadFile(reqDumpPath) + require.NoError(t, err) + + content := string(reqContent) + + // Verify sensitive headers ARE present but redacted + require.Contains(t, content, "Authorization: Bear...2345") + require.Contains(t, content, "X-Api-Key: secr...alue") + require.Contains(t, content, "Cookie: sess...c123") // "session=abc123" is 14 chars, so first 4 + last 4 + + // Verify the full secret values are NOT present + require.NotContains(t, content, "sk-secret-key-12345") + require.NotContains(t, content, "secret-api-key-value") + + // Verify non-sensitive headers ARE present in full + require.Contains(t, content, "Content-Type: application/json") + require.Contains(t, content, "User-Agent: test-client") +} + +func TestMiddleware_RedactsSensitiveResponseHeaders(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + clk := quartz.NewMock(t) + interceptionID := uuid.New() + + middleware := NewMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) + require.NotNil(t, middleware) + + req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) + require.NoError(t, err) + + // Call middleware with a response containing sensitive headers + resp, err := middleware(req, func(r *http.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Proto: "HTTP/1.1", + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader([]byte(`{"ok": true}`))), + } + // Add sensitive response headers + resp.Header.Set("Set-Cookie", "session=secret123; HttpOnly; Secure") + resp.Header.Set("WWW-Authenticate", "Bearer realm=\"api\"") + // Add non-sensitive headers + resp.Header.Set("Content-Type", "application/json") + resp.Header.Set("X-Request-Id", "req-123") + return resp, nil + }) + require.NoError(t, err) + + // Must read and close response body to trigger the streaming dump + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + + // Read the response dump file + modelDir := filepath.Join(tmpDir, "openai", "gpt-4") + respDumpPath := findDumpFile(t, modelDir, SuffixResponse) + respContent, err := os.ReadFile(respDumpPath) + require.NoError(t, err) + + content := string(respContent) + + // Verify sensitive headers are present but redacted + require.Contains(t, content, "Set-Cookie: sess...cure") + // Note: Go canonicalizes WWW-Authenticate to Www-Authenticate + // "Bearer realm=\"api\"" = 18 chars, first 4 = "Bear", last 4 = "api\"" + require.Contains(t, content, "Www-Authenticate: Bear...api\"") + + // Verify full secret values are NOT present + require.NotContains(t, content, "secret123") + require.NotContains(t, content, "realm=\"api\"") + + // Verify non-sensitive headers ARE present in full + require.Contains(t, content, "Content-Type: application/json") + require.Contains(t, content, "X-Request-Id: req-123") +} + +func TestMiddleware_EmptyBaseDir_ReturnsNil(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + middleware := NewMiddleware("", "openai", "gpt-4", uuid.New(), logger, quartz.NewMock(t)) + require.Nil(t, middleware) +} + +func TestMiddleware_PreservesRequestBody(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + clk := quartz.NewMock(t) + interceptionID := uuid.New() + + middleware := NewMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) + require.NotNil(t, middleware) + + originalBody := `{"messages": [{"role": "user", "content": "hello"}]}` + req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(originalBody))) + require.NoError(t, err) + + var capturedBody []byte + _, err = middleware(req, func(r *http.Request) (*http.Response, error) { + // Read the body in the next handler to verify it's still available + capturedBody, _ = io.ReadAll(r.Body) + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Proto: "HTTP/1.1", + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader([]byte(`{}`))), + }, nil + }) + require.NoError(t, err) + + // Verify the body was preserved for the next handler + require.Equal(t, originalBody, string(capturedBody)) +} + +func TestMiddleware_ModelWithSlash(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + clk := quartz.NewMock(t) + interceptionID := uuid.New() + + // Model with slash should have it replaced with dash + middleware := NewMiddleware(tmpDir, "google", "gemini/1.5-pro", interceptionID, logger, clk) + require.NotNil(t, middleware) + + req, err := http.NewRequest(http.MethodPost, "https://api.google.com/v1/chat", bytes.NewReader([]byte(`{}`))) + require.NoError(t, err) + + _, err = middleware(req, func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Proto: "HTTP/1.1", + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader([]byte(`{}`))), + }, nil + }) + require.NoError(t, err) + + // Verify files are created with sanitized model name + modelDir := filepath.Join(tmpDir, "google", "gemini-1.5-pro") + reqDumpPath := findDumpFile(t, modelDir, SuffixRequest) + _, err = os.Stat(reqDumpPath) + require.NoError(t, err) +} + +func TestPrettyPrintJSON(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input []byte + expected string + }{ + { + name: "empty", + input: []byte{}, + expected: "", + }, + { + name: "valid JSON", + input: []byte(`{"key":"value"}`), + expected: "{\n \"key\": \"value\"\n}", + }, + { + name: "invalid JSON returns as-is", + input: []byte("not json"), + expected: "not json", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + result := prettyPrintJSON(tc.input) + require.Equal(t, tc.expected, string(result)) + }) + } +} + +func TestMiddleware_AllSensitiveRequestHeaders(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + clk := quartz.NewMock(t) + interceptionID := uuid.New() + + middleware := NewMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) + require.NotNil(t, middleware) + + req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) + require.NoError(t, err) + + // Set all sensitive headers + req.Header.Set("Authorization", "Bearer sk-secret-key") + req.Header.Set("X-Api-Key", "secret-api-key") + req.Header.Set("Api-Key", "another-secret") + req.Header.Set("X-Auth-Token", "auth-token-val") + req.Header.Set("Cookie", "session=abc123def") + req.Header.Set("Proxy-Authorization", "Basic proxy-creds") + req.Header.Set("X-Amz-Security-Token", "aws-security-token") + + _, err = middleware(req, func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Proto: "HTTP/1.1", + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader([]byte(`{}`))), + }, nil + }) + require.NoError(t, err) + + modelDir := filepath.Join(tmpDir, "openai", "gpt-4") + reqDumpPath := findDumpFile(t, modelDir, SuffixRequest) + reqContent, err := os.ReadFile(reqDumpPath) + require.NoError(t, err) + + content := string(reqContent) + + // Verify none of the full secret values are present + require.NotContains(t, content, "sk-secret-key") + require.NotContains(t, content, "secret-api-key") + require.NotContains(t, content, "another-secret") + require.NotContains(t, content, "auth-token-val") + require.NotContains(t, content, "abc123def") + require.NotContains(t, content, "proxy-creds") + require.NotContains(t, content, "aws-security-token") + require.NotContains(t, content, "google-api-key") + + // But headers themselves are present (redacted) + require.Contains(t, content, "Authorization:") + require.Contains(t, content, "X-Api-Key:") + require.Contains(t, content, "Api-Key:") + require.Contains(t, content, "X-Auth-Token:") + require.Contains(t, content, "Cookie:") + require.Contains(t, content, "Proxy-Authorization:") + require.Contains(t, content, "X-Amz-Security-Token:") +} diff --git a/intercept/apidump/headers.go b/intercept/apidump/headers.go new file mode 100644 index 0000000..0b4047d --- /dev/null +++ b/intercept/apidump/headers.go @@ -0,0 +1,34 @@ +package apidump + +// sensitiveRequestHeaders are headers that should be redacted from request dumps. +var sensitiveRequestHeaders = map[string]struct{}{ + "Authorization": {}, + "X-Api-Key": {}, + "Api-Key": {}, + "X-Auth-Token": {}, + "Cookie": {}, + "Proxy-Authorization": {}, + "X-Amz-Security-Token": {}, +} + +// sensitiveResponseHeaders are headers that should be redacted from response dumps. +// Note: header names use Go's canonical form (http.CanonicalHeaderKey). +var sensitiveResponseHeaders = map[string]struct{}{ + "Set-Cookie": {}, + "Www-Authenticate": {}, + "Proxy-Authenticate": {}, +} + +// redactHeaderValue redacts a sensitive header value, showing only partial content. +// For values >= 8 bytes: shows first 4 and last 4 bytes with "..." in between. +// For values < 8 bytes: shows first and last byte with "..." in between. +func redactHeaderValue(value string) string { + if len(value) >= 8 { + return value[:4] + "..." + value[len(value)-4:] + } + if len(value) >= 2 { + return value[:1] + "..." + value[len(value)-1:] + } + // Single character or empty - just return as-is + return value +} diff --git a/intercept/apidump/headers_test.go b/intercept/apidump/headers_test.go new file mode 100644 index 0000000..a178e61 --- /dev/null +++ b/intercept/apidump/headers_test.go @@ -0,0 +1,92 @@ +package apidump + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRedactHeaderValue(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected string + }{ + { + name: "empty string", + input: "", + expected: "", + }, + { + name: "single char", + input: "a", + expected: "a", + }, + { + name: "two chars", + input: "ab", + expected: "a...b", + }, + { + name: "seven chars", + input: "abcdefg", + expected: "a...g", + }, + { + name: "eight chars - threshold", + input: "abcdefgh", + expected: "abcd...efgh", + }, + { + name: "long value", + input: "Bearer sk-secret-key-12345", + expected: "Bear...2345", + }, + { + name: "realistic api key", + input: "sk-proj-abc123xyz789", + expected: "sk-p...z789", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + result := redactHeaderValue(tc.input) + require.Equal(t, tc.expected, result) + }) + } +} + +func TestSensitiveHeaderLists(t *testing.T) { + t.Parallel() + + // Verify all expected sensitive request headers are in the list + expectedRequestHeaders := []string{ + "Authorization", + "X-Api-Key", + "Api-Key", + "X-Auth-Token", + "Cookie", + "Proxy-Authorization", + "X-Amz-Security-Token", + } + for _, h := range expectedRequestHeaders { + _, ok := sensitiveRequestHeaders[h] + require.True(t, ok, "expected %q to be in sensitiveRequestHeaders", h) + } + + // Verify all expected sensitive response headers are in the list + // Note: header names use Go's canonical form (http.CanonicalHeaderKey) + expectedResponseHeaders := []string{ + "Set-Cookie", + "Www-Authenticate", + "Proxy-Authenticate", + } + for _, h := range expectedResponseHeaders { + _, ok := sensitiveResponseHeaders[h] + require.True(t, ok, "expected %q to be in sensitiveResponseHeaders", h) + } +} diff --git a/intercept/apidump/streaming.go b/intercept/apidump/streaming.go new file mode 100644 index 0000000..1ad5121 --- /dev/null +++ b/intercept/apidump/streaming.go @@ -0,0 +1,72 @@ +package apidump + +import ( + "fmt" + "io" + "os" + "path/filepath" + "sync" +) + +// streamingBodyDumper wraps an io.ReadCloser and writes all data to a dump file +// as it's read, preserving streaming behavior. +type streamingBodyDumper struct { + body io.ReadCloser + dumpPath string + headerData []byte + logger func(err error) + + once sync.Once + file *os.File + initErr error +} + +func (s *streamingBodyDumper) init() { + s.once.Do(func() { + if err := os.MkdirAll(filepath.Dir(s.dumpPath), 0o755); err != nil { + s.initErr = fmt.Errorf("create dump dir: %w", err) + return + } + f, err := os.Create(s.dumpPath) + if err != nil { + s.initErr = fmt.Errorf("create dump file: %w", err) + return + } + s.file = f + // Write headers first. + if _, err := s.file.Write(s.headerData); err != nil { + s.initErr = fmt.Errorf("write headers: %w", err) + s.file.Close() + s.file = nil + } + }) +} + +func (s *streamingBodyDumper) Read(p []byte) (int, error) { + n, err := s.body.Read(p) + if n > 0 { + s.init() + if s.initErr != nil && s.logger != nil { + s.logger(s.initErr) + } + if s.file != nil { + // Write raw bytes as they stream through. + _, _ = s.file.Write(p[:n]) + } + } + return n, err +} + +func (s *streamingBodyDumper) Close() error { + // Ensure init() has completed to avoid racing with Read(). + s.init() + var closeErr error + if s.file != nil { + closeErr = s.file.Close() + } + bodyErr := s.body.Close() + if bodyErr != nil { + return bodyErr + } + return closeErr +} diff --git a/intercept/apidump/streaming_test.go b/intercept/apidump/streaming_test.go new file mode 100644 index 0000000..3d76555 --- /dev/null +++ b/intercept/apidump/streaming_test.go @@ -0,0 +1,125 @@ +package apidump + +import ( + "bytes" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "testing" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/quartz" + "github.com/google/uuid" + "github.com/stretchr/testify/require" +) + +func TestMiddleware_StreamingResponse(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + clk := quartz.NewMock(t) + interceptionID := uuid.New() + + middleware := NewMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) + require.NotNil(t, middleware) + + req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) + require.NoError(t, err) + + // Simulate a streaming response with multiple chunks + chunks := []string{ + "data: {\"chunk\": 1}\n\n", + "data: {\"chunk\": 2}\n\n", + "data: {\"chunk\": 3}\n\n", + "data: [DONE]\n\n", + } + + // Create a pipe to simulate streaming + pr, pw := io.Pipe() + go func() { + for _, chunk := range chunks { + pw.Write([]byte(chunk)) + } + pw.Close() + }() + + resp, err := middleware(req, func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Proto: "HTTP/1.1", + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: pr, + }, nil + }) + require.NoError(t, err) + + // Read response in small chunks to simulate streaming consumption + var receivedData bytes.Buffer + buf := make([]byte, 16) + for { + n, err := resp.Body.Read(buf) + if n > 0 { + receivedData.Write(buf[:n]) + } + if err == io.EOF { + break + } + require.NoError(t, err) + } + require.NoError(t, resp.Body.Close()) + + // Verify we received all the data + expectedData := strings.Join(chunks, "") + require.Equal(t, expectedData, receivedData.String()) + + // Verify the dump file was created and contains all the streamed data + modelDir := filepath.Join(tmpDir, "openai", "gpt-4") + respDumpPath := findDumpFile(t, modelDir, SuffixResponse) + respContent, err := os.ReadFile(respDumpPath) + require.NoError(t, err) + + content := string(respContent) + require.Contains(t, content, "HTTP/1.1 200 OK") + require.Contains(t, content, "Content-Type: text/event-stream") + // All chunks should be in the dump + for _, chunk := range chunks { + require.Contains(t, content, chunk) + } +} + +func TestMiddleware_PreservesResponseBody(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) + clk := quartz.NewMock(t) + interceptionID := uuid.New() + + middleware := NewMiddleware(tmpDir, "openai", "gpt-4", interceptionID, logger, clk) + require.NotNil(t, middleware) + + req, err := http.NewRequest(http.MethodPost, "https://api.openai.com/v1/chat/completions", bytes.NewReader([]byte(`{}`))) + require.NoError(t, err) + + originalRespBody := `{"choices": [{"message": {"content": "hi"}}]}` + resp, err := middleware(req, func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Status: "200 OK", + Proto: "HTTP/1.1", + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader([]byte(originalRespBody))), + }, nil + }) + require.NoError(t, err) + + // Verify the response body is still readable after middleware + capturedBody, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, originalRespBody, string(capturedBody)) +} diff --git a/intercept/chatcompletions/base.go b/intercept/chatcompletions/base.go index b1cf7be..2610b3e 100644 --- a/intercept/chatcompletions/base.go +++ b/intercept/chatcompletions/base.go @@ -9,9 +9,11 @@ import ( "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/intercept/apidump" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/tracing" + "github.com/coder/quartz" "github.com/google/uuid" "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/option" @@ -23,10 +25,9 @@ import ( ) type interceptionBase struct { - id uuid.UUID - req *ChatCompletionNewParamsWrapper - baseURL string - key string + id uuid.UUID + req *ChatCompletionNewParamsWrapper + cfg config.OpenAI logger slog.Logger tracer trace.Tracer @@ -35,8 +36,13 @@ type interceptionBase struct { mcpProxy mcp.ServerProxier } -func (i *interceptionBase) newCompletionsService(baseURL string, key string) openai.ChatCompletionService { - opts := []option.RequestOption{option.WithAPIKey(key), option.WithBaseURL(baseURL)} +func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService { + opts := []option.RequestOption{option.WithAPIKey(i.cfg.Key), option.WithBaseURL(i.cfg.BaseURL)} + + // Add API dump middleware if configured + if mw := apidump.NewMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { + opts = append(opts, option.WithMiddleware(mw)) + } return openai.NewChatCompletionService(opts...) } diff --git a/intercept/chatcompletions/blocking.go b/intercept/chatcompletions/blocking.go index c565bff..54a4c16 100644 --- a/intercept/chatcompletions/blocking.go +++ b/intercept/chatcompletions/blocking.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/coder/aibridge/config" "github.com/coder/aibridge/intercept/eventstream" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" @@ -25,13 +26,12 @@ type BlockingInterception struct { interceptionBase } -func NewBlockingInterceptor(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string, tracer trace.Tracer) *BlockingInterception { +func NewBlockingInterceptor(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg config.OpenAI, tracer trace.Tracer) *BlockingInterception { return &BlockingInterception{interceptionBase: interceptionBase{ - id: id, - req: req, - baseURL: baseURL, - key: key, - tracer: tracer, + id: id, + req: req, + cfg: cfg, + tracer: tracer, }} } @@ -55,7 +55,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...)) defer tracing.EndSpanErr(span, &outErr) - svc := i.newCompletionsService(i.baseURL, i.key) + svc := i.newCompletionsService() logger := i.logger.With(slog.F("model", i.req.Model)) var ( diff --git a/intercept/chatcompletions/streaming.go b/intercept/chatcompletions/streaming.go index f8db748..557c95e 100644 --- a/intercept/chatcompletions/streaming.go +++ b/intercept/chatcompletions/streaming.go @@ -10,6 +10,7 @@ import ( "strings" "time" + "github.com/coder/aibridge/config" "github.com/coder/aibridge/intercept/eventstream" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" @@ -28,13 +29,12 @@ type StreamingInterception struct { interceptionBase } -func NewStreamingInterceptor(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string, tracer trace.Tracer) *StreamingInterception { +func NewStreamingInterceptor(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg config.OpenAI, tracer trace.Tracer) *StreamingInterception { return &StreamingInterception{interceptionBase: interceptionBase{ - id: id, - req: req, - baseURL: baseURL, - key: key, - tracer: tracer, + id: id, + req: req, + cfg: cfg, + tracer: tracer, }} } @@ -80,7 +80,7 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re defer cancel() r = r.WithContext(ctx) // Rewire context for SSE cancellation. - svc := i.newCompletionsService(i.baseURL, i.key) + svc := i.newCompletionsService() logger := i.logger.With(slog.F("model", i.req.Model)) streamCtx, streamCancel := context.WithCancelCause(ctx) diff --git a/intercept/messages/base.go b/intercept/messages/base.go index a5bea47..4c993e1 100644 --- a/intercept/messages/base.go +++ b/intercept/messages/base.go @@ -19,9 +19,11 @@ import ( "github.com/aws/aws-sdk-go-v2/credentials" aibconfig "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/intercept/apidump" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/tracing" + "github.com/coder/quartz" "github.com/google/uuid" "go.opentelemetry.io/otel/attribute" @@ -153,6 +155,11 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio opts = append(opts, option.WithAPIKey(i.cfg.Key)) opts = append(opts, option.WithBaseURL(i.cfg.BaseURL)) + // Add API dump middleware if configured + if mw := apidump.NewMiddleware(i.cfg.APIDumpDir, aibconfig.ProviderAnthropic, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { + opts = append(opts, option.WithMiddleware(mw)) + } + if i.bedrockCfg != nil { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() diff --git a/intercept/responses/base.go b/intercept/responses/base.go index f369dc1..923c754 100644 --- a/intercept/responses/base.go +++ b/intercept/responses/base.go @@ -17,10 +17,12 @@ import ( "cdr.dev/slog/v3" "github.com/coder/aibridge/config" aibcontext "github.com/coder/aibridge/context" + "github.com/coder/aibridge/intercept/apidump" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/metrics" "github.com/coder/aibridge/recorder" "github.com/coder/aibridge/tracing" + "github.com/coder/quartz" "github.com/google/uuid" "github.com/openai/openai-go/v3/option" "github.com/openai/openai-go/v3/responses" @@ -38,8 +40,7 @@ type responsesInterceptionBase struct { id uuid.UUID req *ResponsesNewParamsWrapper reqPayload []byte - baseURL string - apiKey string + cfg config.OpenAI model string recorder recorder.Recorder mcpProxy mcp.ServerProxier @@ -48,9 +49,11 @@ type responsesInterceptionBase struct { } func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService { - opts := []option.RequestOption{ - option.WithBaseURL(i.baseURL), - option.WithAPIKey(i.apiKey), + opts := []option.RequestOption{option.WithBaseURL(i.cfg.BaseURL), option.WithAPIKey(i.cfg.Key)} + + // Add API dump middleware if configured + if mw := apidump.NewMiddleware(i.cfg.APIDumpDir, config.ProviderOpenAI, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil { + opts = append(opts, option.WithMiddleware(mw)) } return responses.NewResponseService(opts...) diff --git a/intercept/responses/blocking.go b/intercept/responses/blocking.go index 6161d6f..dd909fa 100644 --- a/intercept/responses/blocking.go +++ b/intercept/responses/blocking.go @@ -5,6 +5,7 @@ import ( "net/http" "cdr.dev/slog/v3" + "github.com/coder/aibridge/config" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" "github.com/google/uuid" @@ -15,14 +16,13 @@ type BlockingResponsesInterceptor struct { responsesInterceptionBase } -func NewBlockingInterceptor(id uuid.UUID, req *ResponsesNewParamsWrapper, reqPayload []byte, baseURL string, key string, model string) *BlockingResponsesInterceptor { +func NewBlockingInterceptor(id uuid.UUID, req *ResponsesNewParamsWrapper, reqPayload []byte, cfg config.OpenAI, model string) *BlockingResponsesInterceptor { return &BlockingResponsesInterceptor{ responsesInterceptionBase: responsesInterceptionBase{ id: id, req: req, reqPayload: reqPayload, - baseURL: baseURL, - apiKey: key, + cfg: cfg, model: model, }, } diff --git a/intercept/responses/streaming.go b/intercept/responses/streaming.go index 23793e3..f7fc7f1 100644 --- a/intercept/responses/streaming.go +++ b/intercept/responses/streaming.go @@ -8,6 +8,7 @@ import ( "time" "cdr.dev/slog/v3" + "github.com/coder/aibridge/config" "github.com/coder/aibridge/intercept/eventstream" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/recorder" @@ -25,14 +26,13 @@ type StreamingResponsesInterceptor struct { responsesInterceptionBase } -func NewStreamingInterceptor(id uuid.UUID, req *ResponsesNewParamsWrapper, reqPayload []byte, baseURL string, key string, model string) *StreamingResponsesInterceptor { +func NewStreamingInterceptor(id uuid.UUID, req *ResponsesNewParamsWrapper, reqPayload []byte, cfg config.OpenAI, model string) *StreamingResponsesInterceptor { return &StreamingResponsesInterceptor{ responsesInterceptionBase: responsesInterceptionBase{ id: id, req: req, reqPayload: reqPayload, - baseURL: baseURL, - apiKey: key, + cfg: cfg, model: model, }, } diff --git a/provider/anthropic.go b/provider/anthropic.go index b90fa5f..d6e8d24 100644 --- a/provider/anthropic.go +++ b/provider/anthropic.go @@ -46,6 +46,9 @@ func NewAnthropic(cfg config.Anthropic, bedrockCfg *config.AWSBedrock) *Anthropi if cfg.Key == "" { cfg.Key = os.Getenv("ANTHROPIC_API_KEY") } + if cfg.APIDumpDir == "" { + cfg.APIDumpDir = os.Getenv("BRIDGE_DUMP_DIR") + } if cfg.CircuitBreaker != nil { cfg.CircuitBreaker.IsFailure = anthropicIsFailure cfg.CircuitBreaker.OpenErrorResponse = anthropicOpenErrorResponse diff --git a/provider/openai.go b/provider/openai.go index 24ce3d8..951770d 100644 --- a/provider/openai.go +++ b/provider/openai.go @@ -28,8 +28,7 @@ var openAIOpenErrorResponse = func() []byte { // OpenAI allows for interactions with the OpenAI API. type OpenAI struct { - baseURL string - key string + cfg config.OpenAI circuitBreaker *config.CircuitBreaker } @@ -39,18 +38,19 @@ func NewOpenAI(cfg config.OpenAI) *OpenAI { if cfg.BaseURL == "" { cfg.BaseURL = "https://api.openai.com/v1/" } - if cfg.Key == "" { cfg.Key = os.Getenv("OPENAI_API_KEY") } + if cfg.APIDumpDir == "" { + cfg.APIDumpDir = os.Getenv("BRIDGE_DUMP_DIR") + } if cfg.CircuitBreaker != nil { cfg.CircuitBreaker.OpenErrorResponse = openAIOpenErrorResponse } return &OpenAI{ - baseURL: cfg.BaseURL, - key: cfg.Key, + cfg: cfg, circuitBreaker: cfg.CircuitBreaker, } } @@ -103,9 +103,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace } if req.Stream { - interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.baseURL, p.key, tracer) + interceptor = chatcompletions.NewStreamingInterceptor(id, &req, p.cfg, tracer) } else { - interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.baseURL, p.key, tracer) + interceptor = chatcompletions.NewBlockingInterceptor(id, &req, p.cfg, tracer) } case routeResponses: @@ -114,9 +114,9 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace return nil, fmt.Errorf("unmarshal request body: %w", err) } if req.Stream { - interceptor = responses.NewStreamingInterceptor(id, &req, payload, p.baseURL, p.key, string(req.Model)) + interceptor = responses.NewStreamingInterceptor(id, &req, payload, p.cfg, string(req.Model)) } else { - interceptor = responses.NewBlockingInterceptor(id, &req, payload, p.baseURL, p.key, string(req.Model)) + interceptor = responses.NewBlockingInterceptor(id, &req, payload, p.cfg, string(req.Model)) } default: @@ -128,7 +128,7 @@ func (p *OpenAI) CreateInterceptor(w http.ResponseWriter, r *http.Request, trace } func (p *OpenAI) BaseURL() string { - return p.baseURL + return p.cfg.BaseURL } func (p *OpenAI) AuthHeader() string { @@ -140,7 +140,7 @@ func (p *OpenAI) InjectAuthHeader(headers *http.Header) { headers = &http.Header{} } - headers.Set(p.AuthHeader(), "Bearer "+p.key) + headers.Set(p.AuthHeader(), "Bearer "+p.cfg.Key) } func (p *OpenAI) CircuitBreakerConfig() *config.CircuitBreaker {