Skip to content

Commit 6f25879

Browse files
committed
chore: ensure all configs are concurrency-safe
Signed-off-by: Danny Kopping <[email protected]>
1 parent 2de0b07 commit 6f25879

File tree

8 files changed

+127
-55
lines changed

8 files changed

+127
-55
lines changed

bridge_integration_test.go

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,9 @@ func TestAnthropicMessages(t *testing.T) {
126126
recorderClient := &mockRecorderClient{}
127127

128128
logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug)
129-
b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{aibridge.NewAnthropicProvider(cfg(srv.URL, apiKey))}, logger, recorderClient, mcp.NewServerProxyManager(nil))
129+
provider, err := aibridge.NewAnthropicProvider(cfg(srv.URL, apiKey))
130+
require.NoError(t, err)
131+
b, err := aibridge.NewRequestBridge(ctx, []aibridge.Provider{provider}, logger, recorderClient, mcp.NewServerProxyManager(nil))
130132
require.NoError(t, err)
131133

132134
mockSrv := httptest.NewUnstartedServer(b)
@@ -229,7 +231,9 @@ func TestOpenAIChatCompletions(t *testing.T) {
229231
recorderClient := &mockRecorderClient{}
230232

231233
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
232-
b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(cfg(srv.URL, apiKey))}, logger, recorderClient, mcp.NewServerProxyManager(nil))
234+
provider, err := aibridge.NewOpenAIProvider(cfg(srv.URL, apiKey))
235+
require.NoError(t, err)
236+
b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, recorderClient, mcp.NewServerProxyManager(nil))
233237
require.NoError(t, err)
234238

235239
mockSrv := httptest.NewUnstartedServer(b)
@@ -294,7 +298,11 @@ func TestSimple(t *testing.T) {
294298
fixture: antSimple,
295299
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
296300
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
297-
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(cfg(addr, apiKey))}, logger, client, mcp.NewServerProxyManager(nil))
301+
provider, err := aibridge.NewAnthropicProvider(cfg(addr, apiKey))
302+
if err != nil {
303+
return nil, err
304+
}
305+
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil))
298306
},
299307
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
300308
if streaming {
@@ -332,7 +340,11 @@ func TestSimple(t *testing.T) {
332340
fixture: oaiSimple,
333341
configureFunc: func(addr string, client aibridge.Recorder) (*aibridge.RequestBridge, error) {
334342
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
335-
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(cfg(addr, apiKey))}, logger, client, mcp.NewServerProxyManager(nil))
343+
provider, err := aibridge.NewOpenAIProvider(cfg(addr, apiKey))
344+
if err != nil {
345+
return nil, err
346+
}
347+
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil))
336348
},
337349
getResponseIDFunc: func(streaming bool, resp *http.Response) (string, error) {
338350
if streaming {
@@ -470,7 +482,8 @@ func TestFallthrough(t *testing.T) {
470482
fixture: antFallthrough,
471483
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
472484
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
473-
provider := aibridge.NewAnthropicProvider(cfg(addr, apiKey))
485+
provider, err := aibridge.NewAnthropicProvider(cfg(addr, apiKey))
486+
require.NoError(t, err)
474487
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil))
475488
require.NoError(t, err)
476489
return provider, bridge
@@ -481,7 +494,8 @@ func TestFallthrough(t *testing.T) {
481494
fixture: oaiFallthrough,
482495
configureFunc: func(addr string, client aibridge.Recorder) (aibridge.Provider, *aibridge.RequestBridge) {
483496
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
484-
provider := aibridge.NewOpenAIProvider(cfg(addr, apiKey))
497+
provider, err := aibridge.NewOpenAIProvider(cfg(addr, apiKey))
498+
require.NoError(t, err)
485499
bridge, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, mcp.NewServerProxyManager(nil))
486500
require.NoError(t, err)
487501
return provider, bridge
@@ -586,7 +600,11 @@ func TestAnthropicInjectedTools(t *testing.T) {
586600

587601
configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
588602
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
589-
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(cfg(addr, apiKey))}, logger, client, srvProxyMgr)
603+
provider, err := aibridge.NewAnthropicProvider(cfg(addr, apiKey))
604+
if err != nil {
605+
return nil, err
606+
}
607+
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, srvProxyMgr)
590608
}
591609

592610
// Build the requirements & make the assertions which are common to all providers.
@@ -667,7 +685,11 @@ func TestOpenAIInjectedTools(t *testing.T) {
667685

668686
configureFn := func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
669687
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
670-
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(cfg(addr, apiKey))}, logger, client, srvProxyMgr)
688+
provider, err := aibridge.NewOpenAIProvider(cfg(addr, apiKey))
689+
if err != nil {
690+
return nil, err
691+
}
692+
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, srvProxyMgr)
671693
}
672694

673695
// Build the requirements & make the assertions which are common to all providers.
@@ -851,7 +873,11 @@ func TestErrorHandling(t *testing.T) {
851873
createRequestFunc: createAnthropicMessagesReq,
852874
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
853875
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
854-
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewAnthropicProvider(cfg(addr, apiKey))}, logger, client, srvProxyMgr)
876+
provider, err := aibridge.NewAnthropicProvider(cfg(addr, apiKey))
877+
if err != nil {
878+
return nil, err
879+
}
880+
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, srvProxyMgr)
855881
},
856882
responseHandlerFn: func(streaming bool, resp *http.Response) {
857883
if streaming {
@@ -876,7 +902,11 @@ func TestErrorHandling(t *testing.T) {
876902
createRequestFunc: createOpenAIChatCompletionsReq,
877903
configureFunc: func(addr string, client aibridge.Recorder, srvProxyMgr *mcp.ServerProxyManager) (*aibridge.RequestBridge, error) {
878904
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug)
879-
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{aibridge.NewOpenAIProvider(cfg(addr, apiKey))}, logger, client, srvProxyMgr)
905+
provider, err := aibridge.NewOpenAIProvider(cfg(addr, apiKey))
906+
if err != nil {
907+
return nil, err
908+
}
909+
return aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, logger, client, srvProxyMgr)
880910
},
881911
responseHandlerFn: func(streaming bool, resp *http.Response) {
882912
if streaming {
@@ -1153,8 +1183,5 @@ func createMockMCPSrv(t *testing.T) http.Handler {
11531183
}
11541184

11551185
func cfg(url, key string) *aibridge.ProviderConfig {
1156-
return &aibridge.ProviderConfig{
1157-
BaseURL: url,
1158-
Key: key,
1159-
}
1186+
return aibridge.NewProviderConfig(url, key, "")
11601187
}

config.go

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,60 @@
11
package aibridge
22

3-
import "sync/atomic"
3+
import "go.uber.org/atomic"
44

55
type ProviderConfig struct {
6-
BaseURL, Key string
7-
// UpstreamLoggingDir specifies the base directory for upstream logging.
8-
// If empty, os.TempDir() will be used.
9-
// Logs are written to $UpstreamLoggingDir/$provider/$model/$id.{req,res}.log
10-
UpstreamLoggingDir string
11-
// enableUpstreamLogging enables logging of upstream API requests and responses.
6+
baseURL, key atomic.String
7+
upstreamLoggingDir atomic.String
128
enableUpstreamLogging atomic.Bool
139
}
1410

11+
// NewProviderConfig creates a new ProviderConfig with the given values.
12+
func NewProviderConfig(baseURL, key, upstreamLoggingDir string) *ProviderConfig {
13+
cfg := &ProviderConfig{}
14+
cfg.baseURL.Store(baseURL)
15+
cfg.key.Store(key)
16+
cfg.upstreamLoggingDir.Store(upstreamLoggingDir)
17+
return cfg
18+
}
19+
20+
// BaseURL returns the base URL for the provider.
21+
func (c *ProviderConfig) BaseURL() string {
22+
return c.baseURL.Load()
23+
}
24+
25+
// SetBaseURL sets the base URL for the provider.
26+
func (c *ProviderConfig) SetBaseURL(baseURL string) {
27+
c.baseURL.Store(baseURL)
28+
}
29+
30+
// Key returns the API key for the provider.
31+
func (c *ProviderConfig) Key() string {
32+
return c.key.Load()
33+
}
34+
35+
// SetKey sets the API key for the provider.
36+
func (c *ProviderConfig) SetKey(key string) {
37+
c.key.Store(key)
38+
}
39+
40+
// UpstreamLoggingDir returns the base directory for upstream logging.
41+
// If empty, the OS's tempdir will be used.
42+
// Logs are written to $UpstreamLoggingDir/$provider/$model/$interceptionID.{req,res}.log
43+
func (c *ProviderConfig) UpstreamLoggingDir() string {
44+
return c.upstreamLoggingDir.Load()
45+
}
46+
47+
// SetUpstreamLoggingDir sets the base directory for upstream logging.
48+
func (c *ProviderConfig) SetUpstreamLoggingDir(dir string) {
49+
c.upstreamLoggingDir.Store(dir)
50+
}
51+
1552
// SetEnableUpstreamLogging enables or disables upstream logging at runtime.
1653
func (c *ProviderConfig) SetEnableUpstreamLogging(enabled bool) {
1754
c.enableUpstreamLogging.Store(enabled)
1855
}
1956

20-
// EnableUpstreamLogging returns whether upstream logging is currently enabled.
21-
func (c *ProviderConfig) EnableUpstreamLogging() bool {
57+
// IsUpstreamLoggingEnabled returns whether upstream logging is currently enabled.
58+
func (c *ProviderConfig) IsUpstreamLoggingEnabled() bool {
2259
return c.enableUpstreamLogging.Load()
2360
}

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ require (
1111
github.com/stretchr/testify v1.10.0
1212
github.com/tidwall/gjson v1.18.0 // indirect
1313
github.com/tidwall/sjson v1.2.5 // indirect
14+
go.uber.org/atomic v1.11.0
1415
go.uber.org/goleak v1.3.0
1516
go.uber.org/mock v0.6.0
1617
golang.org/x/exp v0.0.0-20250819193227-8b4c13bb791b

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ go.opentelemetry.io/otel/sdk v1.16.0 h1:Z1Ok1YsijYL0CSJpHt4cS3wDDh7p572grzNrBMiM
9494
go.opentelemetry.io/otel/sdk v1.16.0/go.mod h1:tMsIuKXuuIWPBAOrH+eHtvhTL+SntFtXF9QD68aP6p4=
9595
go.opentelemetry.io/otel/trace v1.33.0 h1:cCJuF7LRjUFso9LPnEAHJDB2pqzp+hbO8eu1qqW2d/s=
9696
go.opentelemetry.io/otel/trace v1.33.0/go.mod h1:uIcdVUZMpTAmz0tI1z04GoVSezK37CbGV4fr1f2nBck=
97+
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
98+
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
9799
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
98100
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
99101
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=

provider_anthropic.go

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,21 @@ const (
2929
routeMessages = "/anthropic/v1/messages" // https://docs.anthropic.com/en/api/messages
3030
)
3131

32-
func NewAnthropicProvider(cfg *ProviderConfig) *AnthropicProvider {
33-
if cfg.BaseURL == "" {
34-
cfg.BaseURL = "https://api.anthropic.com/"
32+
func NewAnthropicProvider(cfg *ProviderConfig) (*AnthropicProvider, error) {
33+
if cfg == nil {
34+
return nil, fmt.Errorf("ProviderConfig cannot be nil")
3535
}
36-
if cfg.Key == "" {
37-
cfg.Key = os.Getenv("ANTHROPIC_API_KEY")
36+
37+
if cfg.BaseURL() == "" {
38+
cfg.SetBaseURL("https://api.anthropic.com/")
39+
}
40+
if cfg.Key() == "" {
41+
cfg.SetKey(os.Getenv("ANTHROPIC_API_KEY"))
3842
}
3943

4044
return &AnthropicProvider{
4145
cfg: cfg,
42-
}
46+
}, nil
4347
}
4448

4549
func (p *AnthropicProvider) Name() string {
@@ -84,7 +88,7 @@ func (p *AnthropicProvider) CreateInterceptor(w http.ResponseWriter, r *http.Req
8488
}
8589

8690
func (p *AnthropicProvider) BaseURL() string {
87-
return p.cfg.BaseURL
91+
return p.cfg.BaseURL()
8892
}
8993

9094
func (p *AnthropicProvider) AuthHeader() string {
@@ -96,14 +100,14 @@ func (p *AnthropicProvider) InjectAuthHeader(headers *http.Header) {
96100
headers = &http.Header{}
97101
}
98102

99-
headers.Set(p.AuthHeader(), p.cfg.Key)
103+
headers.Set(p.AuthHeader(), p.cfg.Key())
100104
}
101105

102106
func newAnthropicClient(logger slog.Logger, cfg *ProviderConfig, id, model string, opts ...option.RequestOption) anthropic.Client {
103-
opts = append(opts, option.WithAPIKey(cfg.Key))
104-
opts = append(opts, option.WithBaseURL(cfg.BaseURL))
107+
opts = append(opts, option.WithAPIKey(cfg.Key()))
108+
opts = append(opts, option.WithBaseURL(cfg.BaseURL()))
105109

106-
if cfg.EnableUpstreamLogging() {
110+
if cfg.IsUpstreamLoggingEnabled() {
107111
if middleware := createLoggingMiddleware(logger, cfg, ProviderAnthropic, id, model); middleware != nil {
108112
opts = append(opts, option.WithMiddleware(middleware))
109113
}

provider_openai.go

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,22 @@ const (
2626
routeChatCompletions = "/openai/v1/chat/completions" // https://platform.openai.com/docs/api-reference/chat
2727
)
2828

29-
func NewOpenAIProvider(cfg *ProviderConfig) *OpenAIProvider {
30-
if cfg.BaseURL == "" {
31-
cfg.BaseURL = "https://api.openai.com/v1/"
29+
func NewOpenAIProvider(cfg *ProviderConfig) (*OpenAIProvider, error) {
30+
if cfg == nil {
31+
return nil, fmt.Errorf("ProviderConfig cannot be nil")
3232
}
3333

34-
if cfg.Key == "" {
35-
cfg.Key = os.Getenv("OPENAI_API_KEY")
34+
if cfg.BaseURL() == "" {
35+
cfg.SetBaseURL("https://api.openai.com/v1/")
36+
}
37+
38+
if cfg.Key() == "" {
39+
cfg.SetKey(os.Getenv("OPENAI_API_KEY"))
3640
}
3741

3842
return &OpenAIProvider{
3943
cfg: cfg,
40-
}
44+
}, nil
4145
}
4246

4347
func (p *OpenAIProvider) Name() string {
@@ -86,7 +90,7 @@ func (p *OpenAIProvider) CreateInterceptor(w http.ResponseWriter, r *http.Reques
8690
}
8791

8892
func (p *OpenAIProvider) BaseURL() string {
89-
return p.cfg.BaseURL
93+
return p.cfg.BaseURL()
9094
}
9195

9296
func (p *OpenAIProvider) AuthHeader() string {
@@ -98,15 +102,15 @@ func (p *OpenAIProvider) InjectAuthHeader(headers *http.Header) {
98102
headers = &http.Header{}
99103
}
100104

101-
headers.Set(p.AuthHeader(), "Bearer "+p.cfg.Key)
105+
headers.Set(p.AuthHeader(), "Bearer "+p.cfg.Key())
102106
}
103107

104108
func newOpenAIClient(logger slog.Logger, cfg *ProviderConfig, id, model string) openai.Client {
105109
var opts []option.RequestOption
106-
opts = append(opts, option.WithAPIKey(cfg.Key))
107-
opts = append(opts, option.WithBaseURL(cfg.BaseURL))
110+
opts = append(opts, option.WithAPIKey(cfg.Key()))
111+
opts = append(opts, option.WithBaseURL(cfg.BaseURL()))
108112

109-
if cfg.EnableUpstreamLogging() {
113+
if cfg.IsUpstreamLoggingEnabled() {
110114
if middleware := createLoggingMiddleware(logger, cfg, ProviderOpenAI, id, model); middleware != nil {
111115
opts = append(opts, option.WithMiddleware(middleware))
112116
}

request_logger.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ func createLoggingMiddleware(logger slog.Logger, cfg *ProviderConfig, provider,
7373
ctx := context.Background()
7474
safeModel := SanitizeModelName(model)
7575

76-
baseDir := cfg.UpstreamLoggingDir
76+
baseDir := cfg.UpstreamLoggingDir()
7777
if baseDir == "" {
7878
baseDir = os.TempDir()
7979
}

request_logger_test.go

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,21 @@ func TestRequestLogging(t *testing.T) {
2525
provider string
2626
fixture []byte
2727
route string
28-
createProvider func(*aibridge.ProviderConfig) aibridge.Provider
28+
createProvider func(*aibridge.ProviderConfig) (aibridge.Provider, error)
2929
}{
3030
{
3131
provider: aibridge.ProviderAnthropic,
3232
fixture: antSimple,
3333
route: "/anthropic/v1/messages",
34-
createProvider: func(cfg *aibridge.ProviderConfig) aibridge.Provider {
34+
createProvider: func(cfg *aibridge.ProviderConfig) (aibridge.Provider, error) {
3535
return aibridge.NewAnthropicProvider(cfg)
3636
},
3737
},
3838
{
3939
provider: aibridge.ProviderOpenAI,
4040
fixture: oaiSimple,
4141
route: "/openai/v1/chat/completions",
42-
createProvider: func(cfg *aibridge.ProviderConfig) aibridge.Provider {
42+
createProvider: func(cfg *aibridge.ProviderConfig) (aibridge.Provider, error) {
4343
return aibridge.NewOpenAIProvider(cfg)
4444
},
4545
},
@@ -68,14 +68,11 @@ func TestRequestLogging(t *testing.T) {
6868
}))
6969
t.Cleanup(srv.Close)
7070

71-
cfg := aibridge.ProviderConfig{
72-
BaseURL: srv.URL,
73-
Key: apiKey,
74-
UpstreamLoggingDir: tmpDir,
75-
}
71+
cfg := aibridge.NewProviderConfig(srv.URL, apiKey, tmpDir)
7672
cfg.SetEnableUpstreamLogging(true)
7773

78-
provider := tc.createProvider(&cfg)
74+
provider, err := tc.createProvider(cfg)
75+
require.NoError(t, err)
7976
client := &mockRecorderClient{}
8077
mcpProxy := mcp.NewServerProxyManager(nil)
8178

0 commit comments

Comments
 (0)