Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 42 additions & 15 deletions bridge_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ func TestAnthropicMessages(t *testing.T) {
recorderClient := &mockRecorderClient{}

logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)
b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(cfg(srv.URL, apiKey))}, logger, recorderClient, mcp.NewServerProxyManager(nil))
provider, err := aibridge.NewAnthropicProvider(cfg(srv.URL, apiKey))
require.NoError(t, err)
b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, logger, recorderClient, mcp.NewServerProxyManager(nil))
require.NoError(t, err)

mockSrv := httptest.NewUnstartedServer(b)
Expand Down Expand Up @@ -229,7 +231,9 @@ func TestOpenAIChatCompletions(t *testing.T) {
recorderClient := &mockRecorderClient{}

logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(cfg(srv.URL, apiKey))}, logger, recorderClient, mcp.NewServerProxyManager(nil))
provider, err := aibridge.NewOpenAIProvider(cfg(srv.URL, apiKey))
require.NoError(t, err)
b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, recorderClient, mcp.NewServerProxyManager(nil))
require.NoError(t, err)

mockSrv := httptest.NewUnstartedServer(b)
Expand Down Expand Up @@ -294,7 +298,11 @@ func TestSimple(t *testing.T) {
fixture: antSimple,
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(cfg(addr, apiKey))}, logger, client, mcp.NewServerProxyManager(nil))
provider, err := aibridge.NewAnthropicProvider(cfg(addr, apiKey))
if err != nil {
return nil, err
}
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil))
},
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
if streaming {
Expand Down Expand Up @@ -332,7 +340,11 @@ func TestSimple(t *testing.T) {
fixture: oaiSimple,
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(cfg(addr, apiKey))}, logger, client, mcp.NewServerProxyManager(nil))
provider, err := aibridge.NewOpenAIProvider(cfg(addr, apiKey))
if err != nil {
return nil, err
}
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil))
},
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
if streaming {
Expand Down Expand Up @@ -470,7 +482,8 @@ func TestFallthrough(t *testing.T) {
fixture: antFallthrough,
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
provider := aibridge.NewAnthropicProvider(cfg(addr, apiKey))
provider, err := aibridge.NewAnthropicProvider(cfg(addr, apiKey))
require.NoError(t, err)
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil))
require.NoError(t, err)
return provider, bridge
Expand All @@ -481,7 +494,8 @@ func TestFallthrough(t *testing.T) {
fixture: oaiFallthrough,
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
provider := aibridge.NewOpenAIProvider(cfg(addr, apiKey))
provider, err := aibridge.NewOpenAIProvider(cfg(addr, apiKey))
require.NoError(t, err)
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil))
require.NoError(t, err)
return provider, bridge
Expand Down Expand Up @@ -586,7 +600,11 @@ func TestAnthropicInjectedTools(t *testing.T) {

configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(cfg(addr, apiKey))}, logger, client, srvProxyMgr)
provider, err := aibridge.NewAnthropicProvider(cfg(addr, apiKey))
if err != nil {
return nil, err
}
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, srvProxyMgr)
}

// Build the requirements & make the assertions which are common to all providers.
Expand Down Expand Up @@ -667,7 +685,11 @@ func TestOpenAIInjectedTools(t *testing.T) {

configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(cfg(addr, apiKey))}, logger, client, srvProxyMgr)
provider, err := aibridge.NewOpenAIProvider(cfg(addr, apiKey))
if err != nil {
return nil, err
}
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, srvProxyMgr)
}

// Build the requirements & make the assertions which are common to all providers.
Expand Down Expand Up @@ -851,7 +873,11 @@ func TestErrorHandling(t *testing.T) {
createRequestFunc: createAnthropicMessagesReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(cfg(addr, apiKey))}, logger, client, srvProxyMgr)
provider, err := aibridge.NewAnthropicProvider(cfg(addr, apiKey))
if err != nil {
return nil, err
}
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, srvProxyMgr)
},
responseHandlerFn: func(streaming bool, resp *http.Response) {
if streaming {
Expand All @@ -876,7 +902,11 @@ func TestErrorHandling(t *testing.T) {
createRequestFunc: createOpenAIChatCompletionsReq,
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(cfg(addr, apiKey))}, logger, client, srvProxyMgr)
provider, err := aibridge.NewOpenAIProvider(cfg(addr, apiKey))
if err != nil {
return nil, err
}
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, srvProxyMgr)
},
responseHandlerFn: func(streaming bool, resp *http.Response) {
if streaming {
Expand Down Expand Up @@ -1152,9 +1182,6 @@ func createMockMCPSrv(t *testing.T) http.Handler {
return server.NewStreamableHTTPServer(s)
}

func cfg(url, key string) aibridge.ProviderConfig {
return aibridge.ProviderConfig{
BaseURL: url,
Key: key,
}
// cfg builds a ProviderConfig for tests, pointing the provider at the given
// (mock server) URL with the given API key. The upstream-logging directory is
// left empty, which per NewProviderConfig's contract defaults to the OS tempdir.
func cfg(url, key string) *aibridge.ProviderConfig {
return aibridge.NewProviderConfig(url, key, "")
}
58 changes: 54 additions & 4 deletions config.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,60 @@
package aibridge

import "go.uber.org/atomic"

type ProviderConfig struct {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a lot especially given that the baseURL and key won't change during a provider's lifetime within Coder.

Have you considered moving the methods for changing upstream logging settings to the provider rather than having the config be mutable? In Coder for example, we don't change the coderd.Options type after passing it to the API.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll have a look.

BaseURL, Key string
baseURL, key atomic.String
upstreamLoggingDir atomic.String
enableUpstreamLogging atomic.Bool
Comment on lines +7 to +8
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be two values? An empty string could denote that it's disabled

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An empty upstreamLoggingDir defaults to os.TempDir. We may want to allow configuration of the path in the future, so this is forward-looking.

}

// NewProviderConfig constructs a ProviderConfig and seeds its base URL,
// API key, and upstream-logging directory with the given values.
func NewProviderConfig(baseURL, key, upstreamLoggingDir string) *ProviderConfig {
	var c ProviderConfig
	c.baseURL.Store(baseURL)
	c.key.Store(key)
	c.upstreamLoggingDir.Store(upstreamLoggingDir)
	return &c
}

// BaseURL returns the base URL for the provider.
// Safe for concurrent use: the value is read atomically.
func (c *ProviderConfig) BaseURL() string {
return c.baseURL.Load()
}

// SetBaseURL sets the base URL for the provider.
// Safe for concurrent use: the value is stored atomically.
func (c *ProviderConfig) SetBaseURL(baseURL string) {
c.baseURL.Store(baseURL)
}

// Key returns the API key for the provider.
// Safe for concurrent use: the value is read atomically.
func (c *ProviderConfig) Key() string {
return c.key.Load()
}

// SetKey sets the API key for the provider.
// Safe for concurrent use: the value is stored atomically.
func (c *ProviderConfig) SetKey(key string) {
c.key.Store(key)
}

// UpstreamLoggingDir returns the base directory for upstream logging.
// If empty, the OS's tempdir will be used.
// Logs are written to $UpstreamLoggingDir/$provider/$model/$interceptionID.{req,res}.log
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

File per interception seems like a lot of files. For short debugging it should be fine but not in general. Maybe instead of directory structure / separate files structured logging could work?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Structured logging won't be easy to read because of the large payloads.
You're correct it'll be a lot of files. The point of the runtime enable/disable is so that it can be turned on for short periods of time.
I'll add docs to coder/coder#20414 to indicate this.

// UpstreamLoggingDir returns the base directory for upstream request/response
// logs; an empty value means the OS tempdir is used as the default.
// Safe for concurrent use: the value is read atomically.
func (c *ProviderConfig) UpstreamLoggingDir() string {
return c.upstreamLoggingDir.Load()
}

// SetUpstreamLoggingDir sets the base directory for upstream logging.
// Safe for concurrent use: the value is stored atomically, so it may be
// changed while the provider is serving requests.
func (c *ProviderConfig) SetUpstreamLoggingDir(dir string) {
c.upstreamLoggingDir.Store(dir)
}

// SetEnableUpstreamLogging enables or disables upstream logging at runtime.
// The flag is stored atomically, so it can be toggled concurrently with
// in-flight requests; subsequent checks observe the new value.
func (c *ProviderConfig) SetEnableUpstreamLogging(enabled bool) {
c.enableUpstreamLogging.Store(enabled)
}

type Config struct {
OpenAI ProviderConfig
Anthropic ProviderConfig
// IsUpstreamLoggingEnabled returns whether upstream logging is currently enabled.
// Safe for concurrent use: the flag is read atomically.
func (c *ProviderConfig) IsUpstreamLoggingEnabled() bool {
return c.enableUpstreamLogging.Load()
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ require (
github.com/stretchr/testify v1.10.0
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/sjson v1.2.5
go.uber.org/atomic v1.11.0
go.uber.org/goleak v1.3.0
go.uber.org/mock v0.6.0
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ go.opentelemetry.io/otel/sdk v1.16.0 h1:Z1Ok1YsijYL0CSJpHt4cS3wDDh7p572grzNrBMiM
go.opentelemetry.io/otel/sdk v1.16.0/go.mod h1:tMsIuKXuuIWPBAOrH+eHtvhTL+SntFtXF9QD68aP6p4=
go.opentelemetry.io/otel/trace v1.33.0 h1:cCJuF7LRjUFso9LPnEAHJDB2pqzp+hbO8eu1qqW2d/s=
go.opentelemetry.io/otel/trace v1.33.0/go.mod h1:uIcdVUZMpTAmz0tI1z04GoVSezK37CbGV4fr1f2nBck=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
Expand Down
4 changes: 2 additions & 2 deletions intercept_anthropic_messages_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ type AnthropicMessagesInterceptionBase struct {
id uuid.UUID
req *MessageNewParamsWrapper

baseURL, key string
logger slog.Logger
cfg *ProviderConfig
logger slog.Logger

recorder Recorder
mcpProxy mcp.ServerProxier
Expand Down
11 changes: 5 additions & 6 deletions intercept_anthropic_messages_blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,11 @@ type AnthropicMessagesBlockingInterception struct {
AnthropicMessagesInterceptionBase
}

func NewAnthropicMessagesBlockingInterception(id uuid.UUID, req *MessageNewParamsWrapper, baseURL, key string) *AnthropicMessagesBlockingInterception {
func NewAnthropicMessagesBlockingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg *ProviderConfig) *AnthropicMessagesBlockingInterception {
return &AnthropicMessagesBlockingInterception{AnthropicMessagesInterceptionBase: AnthropicMessagesInterceptionBase{
id: id,
req: req,
baseURL: baseURL,
key: key,
id: id,
req: req,
cfg: cfg,
}}
}

Expand Down Expand Up @@ -58,7 +57,7 @@ func (i *AnthropicMessagesBlockingInterception) ProcessRequest(w http.ResponseWr

opts := []option.RequestOption{option.WithRequestTimeout(time.Second * 60)} // TODO: configurable timeout

client := newAnthropicClient(i.baseURL, i.key, opts...)
client := newAnthropicClient(i.logger, i.cfg, i.id.String(), i.Model(), opts...)
messages := i.req.MessageNewParams
logger := i.logger.With(slog.F("model", i.req.Model))

Expand Down
11 changes: 5 additions & 6 deletions intercept_anthropic_messages_streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@ type AnthropicMessagesStreamingInterception struct {
AnthropicMessagesInterceptionBase
}

func NewAnthropicMessagesStreamingInterception(id uuid.UUID, req *MessageNewParamsWrapper, baseURL, key string) *AnthropicMessagesStreamingInterception {
func NewAnthropicMessagesStreamingInterception(id uuid.UUID, req *MessageNewParamsWrapper, cfg *ProviderConfig) *AnthropicMessagesStreamingInterception {
return &AnthropicMessagesStreamingInterception{AnthropicMessagesInterceptionBase: AnthropicMessagesInterceptionBase{
id: id,
req: req,
baseURL: baseURL,
key: key,
id: id,
req: req,
cfg: cfg,
}}
}

Expand Down Expand Up @@ -95,7 +94,7 @@ func (i *AnthropicMessagesStreamingInterception) ProcessRequest(w http.ResponseW
_ = events.Shutdown(streamCtx) // Catch-all in case it doesn't get shutdown after stream completes.
}()

client := newAnthropicClient(i.baseURL, i.key)
client := newAnthropicClient(i.logger, i.cfg, i.id.String(), i.Model())
messages := i.req.MessageNewParams

// Accumulate usage across the entire streaming interaction (including tool reinvocations).
Expand Down
4 changes: 2 additions & 2 deletions intercept_openai_chat_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ type OpenAIChatInterceptionBase struct {
id uuid.UUID
req *ChatCompletionNewParamsWrapper

baseURL, key string
logger slog.Logger
cfg *ProviderConfig
logger slog.Logger

recorder Recorder
mcpProxy mcp.ServerProxier
Expand Down
11 changes: 5 additions & 6 deletions intercept_openai_chat_blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@ type OpenAIBlockingChatInterception struct {
OpenAIChatInterceptionBase
}

func NewOpenAIBlockingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string) *OpenAIBlockingChatInterception {
func NewOpenAIBlockingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg *ProviderConfig) *OpenAIBlockingChatInterception {
return &OpenAIBlockingChatInterception{OpenAIChatInterceptionBase: OpenAIChatInterceptionBase{
id: id,
req: req,
baseURL: baseURL,
key: key,
id: id,
req: req,
cfg: cfg,
}}
}

Expand All @@ -42,7 +41,7 @@ func (i *OpenAIBlockingChatInterception) ProcessRequest(w http.ResponseWriter, r
}

ctx := r.Context()
client := newOpenAIClient(i.baseURL, i.key)
client := newOpenAIClient(i.logger, i.cfg, i.id.String(), i.Model())
logger := i.logger.With(slog.F("model", i.req.Model))

var (
Expand Down
11 changes: 5 additions & 6 deletions intercept_openai_chat_streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@ type OpenAIStreamingChatInterception struct {
OpenAIChatInterceptionBase
}

func NewOpenAIStreamingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, baseURL, key string) *OpenAIStreamingChatInterception {
func NewOpenAIStreamingChatInterception(id uuid.UUID, req *ChatCompletionNewParamsWrapper, cfg *ProviderConfig) *OpenAIStreamingChatInterception {
return &OpenAIStreamingChatInterception{OpenAIChatInterceptionBase: OpenAIChatInterceptionBase{
id: id,
req: req,
baseURL: baseURL,
key: key,
id: id,
req: req,
cfg: cfg,
}}
}

Expand Down Expand Up @@ -65,7 +64,7 @@ func (i *OpenAIStreamingChatInterception) ProcessRequest(w http.ResponseWriter,
defer cancel()
r = r.WithContext(ctx) // Rewire context for SSE cancellation.

client := newOpenAIClient(i.baseURL, i.key)
client := newOpenAIClient(i.logger, i.cfg, i.id.String(), i.Model())
logger := i.logger.With(slog.F("model", i.req.Model))

streamCtx, streamCancel := context.WithCancelCause(ctx)
Expand Down
Loading