diff --git a/bridge_integration_test.go b/bridge_integration_test.go index b0de5f8..b8f0435 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -126,7 +126,9 @@ func TestAnthropicMessages(t *testing.T) { recorderClient := &mockRecorderClient{} logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(cfg(srv.URL, apiKey))}, logger, recorderClient, mcp.NewServerProxyManager(nil)) + provider, err := aibridge.NewAnthropicProvider(cfg(srv.URL, apiKey)) + require.NoError(t, err) + b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, logger, recorderClient, mcp.NewServerProxyManager(nil)) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -229,7 +231,9 @@ func TestOpenAIChatCompletions(t *testing.T) { recorderClient := &mockRecorderClient{} logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(cfg(srv.URL, apiKey))}, logger, recorderClient, mcp.NewServerProxyManager(nil)) + provider, err := aibridge.NewOpenAIProvider(cfg(srv.URL, apiKey)) + require.NoError(t, err) + b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, recorderClient, mcp.NewServerProxyManager(nil)) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -294,7 +298,11 @@ func TestSimple(t *testing.T) { fixture: antSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(cfg(addr, apiKey))}, logger, client, mcp.NewServerProxyManager(nil)) + provider, err := aibridge.NewAnthropicProvider(cfg(addr, apiKey)) + if err != nil { + return nil, err + } + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil)) }, getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { if streaming { @@ -332,7 +340,11 @@ func TestSimple(t *testing.T) { fixture: oaiSimple, configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(cfg(addr, apiKey))}, logger, client, mcp.NewServerProxyManager(nil)) + provider, err := aibridge.NewOpenAIProvider(cfg(addr, apiKey)) + if err != nil { + return nil, err + } + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil)) }, getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) { if streaming { @@ -470,7 +482,8 @@ func TestFallthrough(t *testing.T) { fixture: antFallthrough, configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - provider := aibridge.NewAnthropicProvider(cfg(addr, apiKey)) + provider, err := aibridge.NewAnthropicProvider(cfg(addr, apiKey)) + require.NoError(t, err) bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil)) require.NoError(t, err) return provider, bridge @@ -481,7 +494,8 @@ func TestFallthrough(t *testing.T) { fixture: oaiFallthrough, configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - provider := aibridge.NewOpenAIProvider(cfg(addr, apiKey)) + provider, err := aibridge.NewOpenAIProvider(cfg(addr, apiKey)) + require.NoError(t, err) bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil)) require.NoError(t, err) return provider, bridge @@ -586,7 +600,11 @@ func TestAnthropicInjectedTools(t *testing.T) { configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(cfg(addr, apiKey))}, logger, client, srvProxyMgr) + provider, err := aibridge.NewAnthropicProvider(cfg(addr, apiKey)) + if err != nil { + return nil, err + } + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, srvProxyMgr) } // Build the requirements & make the assertions which are common to all providers. @@ -667,7 +685,11 @@ func TestOpenAIInjectedTools(t *testing.T) { configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(cfg(addr, apiKey))}, logger, client, srvProxyMgr) + provider, err := aibridge.NewOpenAIProvider(cfg(addr, apiKey)) + if err != nil { + return nil, err + } + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, srvProxyMgr) } // Build the requirements & make the assertions which are common to all providers. @@ -851,7 +873,11 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createAnthropicMessagesReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(cfg(addr, apiKey))}, logger, client, srvProxyMgr) + provider, err := aibridge.NewAnthropicProvider(cfg(addr, apiKey)) + if err != nil { + return nil, err + } + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, srvProxyMgr) }, responseHandlerFn: func(streaming bool, resp *http.Response) { if streaming { @@ -876,7 +902,11 @@ func TestErrorHandling(t *testing.T) { createRequestFunc: createOpenAIChatCompletionsReq, configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(cfg(addr, apiKey))}, logger, client, srvProxyMgr) + provider, err := aibridge.NewOpenAIProvider(cfg(addr, apiKey)) + if err != nil { + return nil, err + } + return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, srvProxyMgr) }, responseHandlerFn: func(streaming bool, resp *http.Response) { if streaming { @@ -1152,9 +1182,6 @@ func createMockMCPSrv(t *testing.T) http.Handler { return server.NewStreamableHTTPServer(s) } -func cfg(url, key string) aibridge.ProviderConfig { - return aibridge.ProviderConfig{ - BaseURL: url, - Key: key, - } +func cfg(url, key string) *aibridge.ProviderConfig { + return aibridge.NewProviderConfig(url, key, "") } diff --git a/config.go b/config.go index 4c99eb6..60c08aa 100644 --- a/config.go +++ b/config.go @@ -1,10 +1,60 @@ package aibridge +import "go.uber.org/atomic" + type ProviderConfig struct { - BaseURL, Key string + baseURL, key atomic.String + upstreamLoggingDir atomic.String + enableUpstreamLogging atomic.Bool +} + +// NewProviderConfig creates a new ProviderConfig with the given values. +func NewProviderConfig(baseURL, key, upstreamLoggingDir string) *ProviderConfig { + cfg := &ProviderConfig{} + cfg.baseURL.Store(baseURL) + cfg.key.Store(key) + cfg.upstreamLoggingDir.Store(upstreamLoggingDir) + return cfg +} + +// BaseURL returns the base URL for the provider. +func (c *ProviderConfig) BaseURL() string { + return c.baseURL.Load() +} + +// SetBaseURL sets the base URL for the provider. +func (c *ProviderConfig) SetBaseURL(baseURL string) { + c.baseURL.Store(baseURL) +} + +// Key returns the API key for the provider. +func (c *ProviderConfig) Key() string { + return c.key.Load() +} + +// SetKey sets the API key for the provider. +func (c *ProviderConfig) SetKey(key string) { + c.key.Store(key) +} + +// UpstreamLoggingDir returns the base directory for upstream logging. +// If empty, the OS's tempdir will be used. +// Logs are written to $UpstreamLoggingDir/$provider/$model/$interceptionID.{req,res}.log +func (c *ProviderConfig) UpstreamLoggingDir() string { + return c.upstreamLoggingDir.Load() +} + +// SetUpstreamLoggingDir sets the base directory for upstream logging. +func (c *ProviderConfig) SetUpstreamLoggingDir(dir string) { + c.upstreamLoggingDir.Store(dir) +} + +// SetEnableUpstreamLogging enables or disables upstream logging at runtime. +func (c *ProviderConfig) SetEnableUpstreamLogging(enabled bool) { + c.enableUpstreamLogging.Store(enabled) } -type Config struct { - OpenAI ProviderConfig - Anthropic ProviderConfig +// IsUpstreamLoggingEnabled returns whether upstream logging is currently enabled. +func (c *ProviderConfig) IsUpstreamLoggingEnabled() bool { + return c.enableUpstreamLogging.Load() } diff --git a/go.mod b/go.mod index 6c241fe..66c925f 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/stretchr/testify v1.10.0 github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/sjson v1.2.5 + go.uber.org/atomic v1.11.0 go.uber.org/goleak v1.3.0 go.uber.org/mock v0.6.0 golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b diff --git a/go.sum b/go.sum index 7815785..119563d 100644 --- a/go.sum +++ b/go.sum @@ -94,6 +94,8 @@ go.opentelemetry.io/otel/sdk v1.16.0 h1:Z1Ok1YsijYL0CSJpHt4cS3wDDh7p572grzNrBMiM go.opentelemetry.io/otel/sdk v1.16.0/go.mod h1:tMsIuKXuuIWPBAOrH+eHtvhTL+SntFtXF9QD68aP6p4= go.opentelemetry.io/otel/trace v1.33.0 h1:cCJuF7LRjUFso9LPnEAHJDB2pqzp+hbO8eu1qqW2d/s= go.opentelemetry.io/otel/trace v1.33.0/go.mod h1:uIcdVUZMpTAmz0tI1z04GoVSezK37CbGV4fr1f2nBck= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= diff --git a/intercept_anthropic_messages_base.go b/intercept_anthropic_messages_base.go index 56c9744..35e8642 100644 --- a/intercept_anthropic_messages_base.go +++ b/intercept_anthropic_messages_base.go @@ -14,8 +14,8 @@ type AnthropicMessagesInterceptionBase struct { id uuid.UUID req *MessageNewParamsWrapper - baseURL, key string - logger slog.Logger + cfg *ProviderConfig + logger slog.Logger recorder Recorder mcpProxy mcp.ServerProxier diff --git a/intercept_anthropic_messages_blocking.go b/intercept_anthropic_messages_blocking.go index d80a25d..ccfff75 100644 --- a/intercept_anthropic_messages_blocking.go +++ b/intercept_anthropic_messages_blocking.go @@ -22,12 +22,11 @@ type AnthropicMessagesBlockingInterception struct { AnthropicMessagesInterceptionBase } -func NewAnthropicMessagesBlockingInterception(id uuid.UUID, req *MessageNewParamsWrapper, baseURL, key string) *AnthropicMessagesBlockingInterception { +func NewAnthropicMessagesBlockingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg *ProviderConfig) *AnthropicMessagesBlockingInterception { return &AnthropicMessagesBlockingInterception{AnthropicMessagesInterceptionBase: AnthropicMessagesInterceptionBase{ - id: id, - req: req, - baseURL: baseURL, - key: key, + id: id, + req: req, + cfg: cfg, }} } @@ -58,7 +57,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr opts := []option.RequestOption{option.WithRequestTimeout(time.Second * 60)} // TODO: configurable timeout - client := newAnthropicClient(i.baseURL, i.key, opts...) + client := newAnthropicClient(i.logger, i.cfg, i.id.String(), i.Model(), opts...) messages := i.req.MessageNewParams logger := i.logger.With(slog.F("model", i.req.Model)) diff --git a/intercept_anthropic_messages_streaming.go b/intercept_anthropic_messages_streaming.go index c2ad7e0..4437a44 100644 --- a/intercept_anthropic_messages_streaming.go +++ b/intercept_anthropic_messages_streaming.go @@ -25,12 +25,11 @@ type AnthropicMessagesStreamingInterception struct { AnthropicMessagesInterceptionBase } -func NewAnthropicMessagesStreamingInterception(id uuid.UUID, req *MessageNewParamsWrapper, baseURL, key string) *AnthropicMessagesStreamingInterception { +func NewAnthropicMessagesStreamingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg *ProviderConfig) *AnthropicMessagesStreamingInterception { return &AnthropicMessagesStreamingInterception{AnthropicMessagesInterceptionBase: AnthropicMessagesInterceptionBase{ - id: id, - req: req, - baseURL: baseURL, - key: key, + id: id, + req: req, + cfg: cfg, }} } @@ -95,7 +94,7 @@ func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseW _ = events.Shutdown(streamCtx) // Catch-all in case it doesn't get shutdown after stream completes. }() - client := newAnthropicClient(i.baseURL, i.key) + client := newAnthropicClient(i.logger, i.cfg, i.id.String(), i.Model()) messages := i.req.MessageNewParams // Accumulate usage across the entire streaming interaction (including tool reinvocations). diff --git a/intercept_openai_chat_base.go b/intercept_openai_chat_base.go index 36b8ff0..4be1c77 100644 --- a/intercept_openai_chat_base.go +++ b/intercept_openai_chat_base.go @@ -17,8 +17,8 @@ type OpenAIChatInterceptionBase struct { id uuid.UUID req *ChatCompletionNewParamsWrapper - baseURL, key string - logger slog.Logger + cfg *ProviderConfig + logger slog.Logger recorder Recorder mcpProxy mcp.ServerProxier diff --git a/intercept_openai_chat_blocking.go b/intercept_openai_chat_blocking.go index 3b1fa7e..f7321b9 100644 --- a/intercept_openai_chat_blocking.go +++ b/intercept_openai_chat_blocking.go @@ -23,12 +23,11 @@ type OpenAIBlockingChatInterception struct { OpenAIChatInterceptionBase } -func NewOpenAIBlockingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string) *OpenAIBlockingChatInterception { +func NewOpenAIBlockingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg *ProviderConfig) *OpenAIBlockingChatInterception { return &OpenAIBlockingChatInterception{OpenAIChatInterceptionBase: OpenAIChatInterceptionBase{ - id: id, - req: req, - baseURL: baseURL, - key: key, + id: id, + req: req, + cfg: cfg, }} } @@ -42,7 +41,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r } ctx := r.Context() - client := newOpenAIClient(i.baseURL, i.key) + client := newOpenAIClient(i.logger, i.cfg, i.id.String(), i.Model()) logger := i.logger.With(slog.F("model", i.req.Model)) var ( diff --git a/intercept_openai_chat_streaming.go b/intercept_openai_chat_streaming.go index 0c5f554..b134930 100644 --- a/intercept_openai_chat_streaming.go +++ b/intercept_openai_chat_streaming.go @@ -25,12 +25,11 @@ type OpenAIStreamingChatInterception struct { OpenAIChatInterceptionBase } -func NewOpenAIStreamingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string) *OpenAIStreamingChatInterception { +func NewOpenAIStreamingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg *ProviderConfig) *OpenAIStreamingChatInterception { return &OpenAIStreamingChatInterception{OpenAIChatInterceptionBase: OpenAIChatInterceptionBase{ - id: id, - req: req, - baseURL: baseURL, - key: key, + id: id, + req: req, + cfg: cfg, }} } @@ -65,7 +64,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter, defer cancel() r = r.WithContext(ctx) // Rewire context for SSE cancellation. - client := newOpenAIClient(i.baseURL, i.key) + client := newOpenAIClient(i.logger, i.cfg, i.id.String(), i.Model()) logger := i.logger.With(slog.F("model", i.req.Model)) streamCtx, streamCancel := context.WithCancelCause(ctx) diff --git a/provider_anthropic.go b/provider_anthropic.go index 192b230..e67eeb7 100644 --- a/provider_anthropic.go +++ b/provider_anthropic.go @@ -8,6 +8,7 @@ import ( "net/http" "os" + "cdr.dev/slog" "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/option" "github.com/anthropics/anthropic-sdk-go/shared" @@ -19,7 +20,7 @@ var _ Provider = &AnthropicProvider{} // AnthropicProvider allows for interactions with the Anthropic API. type AnthropicProvider struct { - baseURL, key string + cfg *ProviderConfig } const ( @@ -28,18 +29,21 @@ const ( routeMessages = "/anthropic/v1/messages" // https://docs.anthropic.com/en/api/messages ) -func NewAnthropicProvider(cfg ProviderConfig) *AnthropicProvider { - if cfg.BaseURL == "" { - cfg.BaseURL = "https://api.anthropic.com/" +func NewAnthropicProvider(cfg *ProviderConfig) (*AnthropicProvider, error) { + if cfg == nil { + return nil, fmt.Errorf("ProviderConfig cannot be nil") } - if cfg.Key == "" { - cfg.Key = os.Getenv("ANTHROPIC_API_KEY") + + if cfg.BaseURL() == "" { + cfg.SetBaseURL("https://api.anthropic.com/") + } + if cfg.Key() == "" { + cfg.SetKey(os.Getenv("ANTHROPIC_API_KEY")) } return &AnthropicProvider{ - baseURL: cfg.BaseURL, - key: cfg.Key, - } + cfg: cfg, + }, nil } func (p *AnthropicProvider) Name() string { @@ -74,17 +78,17 @@ func (p *AnthropicProvider) CreateInterceptor(w http.ResponseWriter, r *http.Req } if req.Stream { - return NewAnthropicMessagesStreamingInterception(id, &req, p.baseURL, p.key), nil + return NewAnthropicMessagesStreamingInterception(id, &req, p.cfg), nil } - return NewAnthropicMessagesBlockingInterception(id, &req, p.baseURL, p.key), nil + return NewAnthropicMessagesBlockingInterception(id, &req, p.cfg), nil } return nil, UnknownRoute } func (p *AnthropicProvider) BaseURL() string { - return p.baseURL + return p.cfg.BaseURL() } func (p *AnthropicProvider) AuthHeader() string { @@ -96,12 +100,18 @@ func (p *AnthropicProvider) InjectAuthHeader(headers *http.Header) { headers = &http.Header{} } - headers.Set(p.AuthHeader(), p.key) + headers.Set(p.AuthHeader(), p.cfg.Key()) } -func newAnthropicClient(baseURL, key string, opts ...option.RequestOption) anthropic.Client { - opts = append(opts, option.WithAPIKey(key)) - opts = append(opts, option.WithBaseURL(baseURL)) +func newAnthropicClient(logger slog.Logger, cfg *ProviderConfig, id, model string, opts ...option.RequestOption) anthropic.Client { + opts = append(opts, option.WithAPIKey(cfg.Key())) + opts = append(opts, option.WithBaseURL(cfg.BaseURL())) + + if cfg.IsUpstreamLoggingEnabled() { + if middleware := createLoggingMiddleware(logger, cfg, ProviderAnthropic, id, model); middleware != nil { + opts = append(opts, option.WithMiddleware(middleware)) + } + } return anthropic.NewClient(opts...) } diff --git a/provider_openai.go b/provider_openai.go index 5779e05..8b645ca 100644 --- a/provider_openai.go +++ b/provider_openai.go @@ -7,6 +7,7 @@ import ( "net/http" "os" + "cdr.dev/slog" "github.com/google/uuid" "github.com/openai/openai-go/v2" "github.com/openai/openai-go/v2/option" @@ -16,7 +17,7 @@ var _ Provider = &OpenAIProvider{} // OpenAIProvider allows for interactions with the OpenAI API. type OpenAIProvider struct { - baseURL, key string + cfg *ProviderConfig } const ( @@ -25,19 +26,22 @@ const ( routeChatCompletions = "/openai/v1/chat/completions" // https://platform.openai.com/docs/api-reference/chat ) -func NewOpenAIProvider(cfg ProviderConfig) *OpenAIProvider { - if cfg.BaseURL == "" { - cfg.BaseURL = "https://api.openai.com/v1/" +func NewOpenAIProvider(cfg *ProviderConfig) (*OpenAIProvider, error) { + if cfg == nil { + return nil, fmt.Errorf("ProviderConfig cannot be nil") } - if cfg.Key == "" { - cfg.Key = os.Getenv("OPENAI_API_KEY") + if cfg.BaseURL() == "" { + cfg.SetBaseURL("https://api.openai.com/v1/") } - return &OpenAIProvider{ - baseURL: cfg.BaseURL, - key: cfg.Key, + if cfg.Key() == "" { + cfg.SetKey(os.Getenv("OPENAI_API_KEY")) } + + return &OpenAIProvider{ + cfg: cfg, + }, nil } func (p *OpenAIProvider) Name() string { @@ -76,9 +80,9 @@ func (p *OpenAIProvider) CreateInterceptor(w http.ResponseWriter, r *http.Reques } if req.Stream { - return NewOpenAIStreamingChatInterception(id, &req, p.baseURL, p.key), nil + return NewOpenAIStreamingChatInterception(id, &req, p.cfg), nil } else { - return NewOpenAIBlockingChatInterception(id, &req, p.baseURL, p.key), nil + return NewOpenAIBlockingChatInterception(id, &req, p.cfg), nil } } @@ -86,7 +90,7 @@ func (p *OpenAIProvider) CreateInterceptor(w http.ResponseWriter, r *http.Reques } func (p *OpenAIProvider) BaseURL() string { - return p.baseURL + return p.cfg.BaseURL() } func (p *OpenAIProvider) AuthHeader() string { @@ -98,13 +102,19 @@ func (p *OpenAIProvider) InjectAuthHeader(headers *http.Header) { headers = &http.Header{} } - headers.Set(p.AuthHeader(), "Bearer "+p.key) + headers.Set(p.AuthHeader(), "Bearer "+p.cfg.Key()) } -func newOpenAIClient(baseURL, key string) openai.Client { +func newOpenAIClient(logger slog.Logger, cfg *ProviderConfig, id, model string) openai.Client { var opts []option.RequestOption - opts = append(opts, option.WithAPIKey(key)) - opts = append(opts, option.WithBaseURL(baseURL)) + opts = append(opts, option.WithAPIKey(cfg.Key())) + opts = append(opts, option.WithBaseURL(cfg.BaseURL())) + + if cfg.IsUpstreamLoggingEnabled() { + if middleware := createLoggingMiddleware(logger, cfg, ProviderOpenAI, id, model); middleware != nil { + opts = append(opts, option.WithMiddleware(middleware)) + } + } return openai.NewClient(opts...) } diff --git a/request_logger.go b/request_logger.go new file mode 100644 index 0000000..0a71c4a --- /dev/null +++ b/request_logger.go @@ -0,0 +1,121 @@ +package aibridge + +import ( + "context" + "fmt" + "log" + "net/http" + "net/http/httputil" + "os" + "path/filepath" + "strings" + + "cdr.dev/slog" +) + +// SanitizeModelName makes a model name safe for use as a directory name. +// Replaces filesystem-unsafe characters with underscores. +func SanitizeModelName(model string) string { + repl := "_" + replacer := strings.NewReplacer( + "/", repl, + "\\", repl, + ":", repl, + "*", repl, + "?", repl, + "\"", repl, + "<", repl, + ">", repl, + "|", repl, + ) + return replacer.Replace(model) +} + +// logUpstreamRequest logs an HTTP request with the given ID and model name. +// The prefix format is: [req] [id] [model] +func logUpstreamRequest(logger *log.Logger, id, model string, req *http.Request) { + if logger == nil { + return + } + + if reqDump, err := httputil.DumpRequest(req, true); err == nil { + logger.Printf("[req] [%s] [%s] %s", id, model, reqDump) + } +} + +// logUpstreamResponse logs an HTTP response with the given ID and model name. +// The prefix format is: [res] [id] [model] +func logUpstreamResponse(logger *log.Logger, id, model string, resp *http.Response) { + if logger == nil { + return + } + + if respDump, err := httputil.DumpResponse(resp, true); err == nil { + logger.Printf("[res] [%s] [%s] %s", id, model, respDump) + } +} + +// logUpstreamError logs an error that occurred during request/response processing. +// The prefix format is: [res] [id] [model] Error: +func logUpstreamError(logger *log.Logger, id, model string, err error) { + if logger == nil { + return + } + + logger.Printf("[res] [%s] [%s] Error: %v", id, model, err) +} + +// createLoggingMiddleware creates a middleware function that logs requests and responses. +// Logs are written to $baseDir/$provider/$model/$id.req.log and $baseDir/$provider/$model/$id.res.log +// where baseDir is from cfg.UpstreamLoggingDir or os.TempDir() if not specified. +// Returns nil if logging setup fails, logging errors via the provided logger. +func createLoggingMiddleware(logger slog.Logger, cfg *ProviderConfig, provider, id, model string) func(*http.Request, func(*http.Request) (*http.Response, error)) (*http.Response, error) { + ctx := context.Background() + safeModel := SanitizeModelName(model) + + baseDir := cfg.UpstreamLoggingDir() + if baseDir == "" { + baseDir = os.TempDir() + } + + logDir := filepath.Join(baseDir, provider, safeModel) + + // Create the directory structure if it doesn't exist + if err := os.MkdirAll(logDir, 0755); err != nil { + logger.Warn(ctx, "failed to create log directory", slog.Error(err), slog.F("dir", logDir)) + return nil + } + + reqLogPath := filepath.Join(logDir, fmt.Sprintf("%s.req.log", id)) + resLogPath := filepath.Join(logDir, fmt.Sprintf("%s.res.log", id)) + + reqLogFile, err := os.OpenFile(reqLogPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + logger.Warn(ctx, "failed to open request log file", slog.Error(err), slog.F("path", reqLogPath)) + return nil + } + + resLogFile, err := os.OpenFile(resLogPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + reqLogFile.Close() + logger.Warn(ctx, "failed to open response log file", slog.Error(err), slog.F("path", resLogPath)) + return nil + } + + reqLogger := log.New(reqLogFile, "", log.LstdFlags) + resLogger := log.New(resLogFile, "", log.LstdFlags) + + return func(req *http.Request, next func(*http.Request) (*http.Response, error)) (*http.Response, error) { + logUpstreamRequest(reqLogger, id, model, req) + + resp, err := next(req) + if err != nil { + logUpstreamError(resLogger, id, model, err) + return resp, err + } + + logUpstreamResponse(resLogger, id, model, resp) + + return resp, err + } +} diff --git a/request_logger_test.go b/request_logger_test.go new file mode 100644 index 0000000..d451435 --- /dev/null +++ b/request_logger_test.go @@ -0,0 +1,172 @@ +package aibridge_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + "cdr.dev/slog" + "cdr.dev/slog/sloggers/slogtest" + "github.com/coder/aibridge" + "github.com/coder/aibridge/mcp" + "github.com/stretchr/testify/require" + "golang.org/x/tools/txtar" +) + +func TestRequestLogging(t *testing.T) { + t.Parallel() + + testCases := []struct { + provider string + fixture []byte + route string + createProvider func(*aibridge.ProviderConfig) (aibridge.Provider, error) + }{ + { + provider: aibridge.ProviderAnthropic, + fixture: antSimple, + route: "/anthropic/v1/messages", + createProvider: func(cfg *aibridge.ProviderConfig) (aibridge.Provider, error) { + return aibridge.NewAnthropicProvider(cfg) + }, + }, + { + provider: aibridge.ProviderOpenAI, + fixture: oaiSimple, + route: "/openai/v1/chat/completions", + createProvider: func(cfg *aibridge.ProviderConfig) (aibridge.Provider, error) { + return aibridge.NewOpenAIProvider(cfg) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.provider, func(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug) + + // Use a temp dir for this test + tmpDir := t.TempDir() + + // Parse fixture + arc := txtar.Parse(tc.fixture) + files := filesMap(arc) + require.Contains(t, files, fixtureRequest) + require.Contains(t, files, fixtureNonStreamingResponse) + + // Create mock server + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(files[fixtureNonStreamingResponse]) + })) + t.Cleanup(srv.Close) + + cfg := aibridge.NewProviderConfig(srv.URL, apiKey, tmpDir) + cfg.SetEnableUpstreamLogging(true) + + provider, err := tc.createProvider(cfg) + require.NoError(t, err) + client := &mockRecorderClient{} + mcpProxy := mcp.NewServerProxyManager(nil) + + bridge, err := aibridge.NewRequestBridge(context.Background(), []aibridge.Provider{provider}, logger, client, mcpProxy) + require.NoError(t, err) + t.Cleanup(func() { + _ = bridge.Shutdown(context.Background()) + }) + + // Make a request + req, err := http.NewRequestWithContext(t.Context(), "POST", tc.route, strings.NewReader(string(files[fixtureRequest]))) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + req = req.WithContext(aibridge.AsActor(req.Context(), userID, nil)) + rec := httptest.NewRecorder() + bridge.ServeHTTP(rec, req) + require.Equal(t, 200, rec.Code) + + // Check that log files were created + // Parse the request to get the model name + var reqData map[string]any + require.NoError(t, json.Unmarshal(files[fixtureRequest], &reqData)) + model := reqData["model"].(string) + + logDir := filepath.Join(tmpDir, tc.provider, model) + entries, err := os.ReadDir(logDir) + require.NoError(t, err, "log directory should exist") + require.NotEmpty(t, entries, "log directory should contain files") + + // Should have at least one .req.log and one .res.log file + var hasReq, hasRes bool + for _, entry := range entries { + name := entry.Name() + if strings.HasSuffix(name, ".req.log") { + hasReq = true + // Verify the file has content + content, err := os.ReadFile(filepath.Join(logDir, name)) + require.NoError(t, err) + require.NotEmpty(t, content, "request log should have content") + require.Contains(t, string(content), "POST") + } else if strings.HasSuffix(name, ".res.log") { + hasRes = true + // Verify the file has content + content, err := os.ReadFile(filepath.Join(logDir, name)) + require.NoError(t, err) + require.NotEmpty(t, content, "response log should have content") + require.Contains(t, string(content), "200") + } + } + require.True(t, hasReq, "should have request log file") + require.True(t, hasRes, "should have response log file") + }) + } +} + +func TestSanitizeModelName(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected string + }{ + { + name: "simple model", + input: "gpt-4o", + expected: "gpt-4o", + }, + { + name: "model with slash", + input: "gpt-4o/mini", + expected: "gpt-4o_mini", + }, + { + name: "model with colon", + input: "o1:2024-12-17", + expected: "o1_2024-12-17", + }, + { + name: "model with backslash", + input: "model\\name", + expected: "model_name", + }, + { + name: "model with multiple special chars", + input: "model:name/version?", + expected: "model_name_version_", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := aibridge.SanitizeModelName(tt.input) + require.Equal(t, tt.expected, result) + }) + } +}