Skip to content

Commit a4c07af

Browse files
committed
chore: refactoring, tests
Signed-off-by: Danny Kopping <[email protected]>
1 parent 0452dc8 commit a4c07af

File tree

6 files changed

+207
-20
lines changed

6 files changed

+207
-20
lines changed

bridge_integration_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,8 +1152,8 @@ func createMockMCPSrv(t *testing.T) http.Handler {
11521152
return server.NewStreamableHTTPServer(s)
11531153
}
11541154

1155-
func cfg(url, key string) aibridge.ProviderConfig {
1156-
return aibridge.ProviderConfig{
1155+
func cfg(url, key string) *aibridge.ProviderConfig {
1156+
return &aibridge.ProviderConfig{
11571157
BaseURL: url,
11581158
Key: key,
11591159
}

config.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@ import "sync/atomic"
44

55
type ProviderConfig struct {
66
BaseURL, Key string
7-
// EnableUpstreamLogging enables logging of upstream API requests and responses to /tmp/$provider.log
7+
// UpstreamLoggingDir specifies the base directory for upstream logging.
8+
// If empty, os.TempDir() will be used.
9+
// Logs are written to $UpstreamLoggingDir/$provider/$model/$id.{req,res}.log
10+
UpstreamLoggingDir string
11+
// enableUpstreamLogging enables logging of upstream API requests and responses.
812
enableUpstreamLogging atomic.Bool
913
}
1014

provider_anthropic.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ func newAnthropicClient(logger slog.Logger, cfg *ProviderConfig, id, model strin
104104
opts = append(opts, option.WithBaseURL(cfg.BaseURL))
105105

106106
if cfg.EnableUpstreamLogging() {
107-
if middleware := createLoggingMiddleware(logger, "anthropic", id, model); middleware != nil {
107+
if middleware := createLoggingMiddleware(logger, cfg, ProviderAnthropic, id, model); middleware != nil {
108108
opts = append(opts, option.WithMiddleware(middleware))
109109
}
110110
}

provider_openai.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ func newOpenAIClient(logger slog.Logger, cfg *ProviderConfig, id, model string)
107107
opts = append(opts, option.WithBaseURL(cfg.BaseURL))
108108

109109
if cfg.EnableUpstreamLogging() {
110-
if middleware := createLoggingMiddleware(logger, "openai", id, model); middleware != nil {
110+
if middleware := createLoggingMiddleware(logger, cfg, ProviderOpenAI, id, model); middleware != nil {
111111
opts = append(opts, option.WithMiddleware(middleware))
112112
}
113113
}

request_logger.go

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,20 @@ import (
1313
"cdr.dev/slog"
1414
)
1515

16-
// sanitizeModelName makes a model name safe for use as a directory name.
16+
// SanitizeModelName makes a model name safe for use as a directory name.
1717
// Replaces filesystem-unsafe characters with underscores.
18-
func sanitizeModelName(model string) string {
18+
func SanitizeModelName(model string) string {
19+
repl := "_"
1920
replacer := strings.NewReplacer(
20-
"/", "_",
21-
"\\", "_",
22-
":", "_",
23-
"*", "_",
24-
"?", "_",
25-
"\"", "_",
26-
"<", "_",
27-
">", "_",
28-
"|", "_",
21+
"/", repl,
22+
"\\", repl,
23+
":", repl,
24+
"*", repl,
25+
"?", repl,
26+
"\"", repl,
27+
"<", repl,
28+
">", repl,
29+
"|", repl,
2930
)
3031
return replacer.Replace(model)
3132
}
@@ -65,12 +66,19 @@ func logUpstreamError(logger *log.Logger, id, model string, err error) {
6566
}
6667

6768
// createLoggingMiddleware creates a middleware function that logs requests and responses.
68-
// Logs are written to $TMPDIR/$provider/$model/$id.req.log and $TMPDIR/$provider/$model/$id.res.log
69+
// Logs are written to $baseDir/$provider/$model/$id.req.log and $baseDir/$provider/$model/$id.res.log
70+
// where baseDir is from cfg.UpstreamLoggingDir or os.TempDir() if not specified.
6971
// Returns nil if logging setup fails, logging errors via the provided logger.
70-
func createLoggingMiddleware(logger slog.Logger, provider, id, model string) func(*http.Request, func(*http.Request) (*http.Response, error)) (*http.Response, error) {
72+
func createLoggingMiddleware(logger slog.Logger, cfg *ProviderConfig, provider, id, model string) func(*http.Request, func(*http.Request) (*http.Response, error)) (*http.Response, error) {
7173
ctx := context.Background()
72-
safeModel := sanitizeModelName(model)
73-
logDir := filepath.Join(os.TempDir(), provider, safeModel)
74+
safeModel := SanitizeModelName(model)
75+
76+
baseDir := cfg.UpstreamLoggingDir
77+
if baseDir == "" {
78+
baseDir = os.TempDir()
79+
}
80+
81+
logDir := filepath.Join(baseDir, provider, safeModel)
7482

7583
// Create the directory structure if it doesn't exist
7684
if err := os.MkdirAll(logDir, 0755); err != nil {

request_logger_test.go

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
package aibridge_test
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"net/http"
7+
"net/http/httptest"
8+
"os"
9+
"path/filepath"
10+
"strings"
11+
"testing"
12+
13+
"cdr.dev/slog"
14+
"cdr.dev/slog/sloggers/slogtest"
15+
"github.com/coder/aibridge"
16+
"github.com/coder/aibridge/mcp"
17+
"github.com/stretchr/testify/require"
18+
"golang.org/x/tools/txtar"
19+
)
20+
21+
func TestRequestLogging(t *testing.T) {
22+
t.Parallel()
23+
24+
testCases := []struct {
25+
provider string
26+
fixture []byte
27+
route string
28+
createProvider func(*aibridge.ProviderConfig) aibridge.Provider
29+
}{
30+
{
31+
provider: aibridge.ProviderAnthropic,
32+
fixture: antSimple,
33+
route: "/anthropic/v1/messages",
34+
createProvider: func(cfg *aibridge.ProviderConfig) aibridge.Provider {
35+
return aibridge.NewAnthropicProvider(cfg)
36+
},
37+
},
38+
{
39+
provider: aibridge.ProviderOpenAI,
40+
fixture: oaiSimple,
41+
route: "/openai/v1/chat/completions",
42+
createProvider: func(cfg *aibridge.ProviderConfig) aibridge.Provider {
43+
return aibridge.NewOpenAIProvider(cfg)
44+
},
45+
},
46+
}
47+
48+
for _, tc := range testCases {
49+
t.Run(tc.provider, func(t *testing.T) {
50+
t.Parallel()
51+
52+
logger := slogtest.Make(t, nil).Leveled(slog.LevelDebug)
53+
54+
// Use a temp dir for this test
55+
tmpDir := t.TempDir()
56+
57+
// Parse fixture
58+
arc := txtar.Parse(tc.fixture)
59+
files := filesMap(arc)
60+
require.Contains(t, files, fixtureRequest)
61+
require.Contains(t, files, fixtureNonStreamingResponse)
62+
63+
// Create mock server
64+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
65+
w.Header().Set("Content-Type", "application/json")
66+
w.WriteHeader(http.StatusOK)
67+
_, _ = w.Write(files[fixtureNonStreamingResponse])
68+
}))
69+
t.Cleanup(srv.Close)
70+
71+
cfg := aibridge.ProviderConfig{
72+
BaseURL: srv.URL,
73+
Key: apiKey,
74+
UpstreamLoggingDir: tmpDir,
75+
}
76+
cfg.SetEnableUpstreamLogging(true)
77+
78+
provider := tc.createProvider(&cfg)
79+
client := &mockRecorderClient{}
80+
mcpProxy := mcp.NewServerProxyManager(nil)
81+
82+
bridge, err := aibridge.NewRequestBridge(context.Background(), []aibridge.Provider{provider}, logger, client, mcpProxy)
83+
require.NoError(t, err)
84+
t.Cleanup(func() {
85+
_ = bridge.Shutdown(context.Background())
86+
})
87+
88+
// Make a request
89+
req, err := http.NewRequestWithContext(t.Context(), "POST", tc.route, strings.NewReader(string(files[fixtureRequest])))
90+
require.NoError(t, err)
91+
req.Header.Set("Content-Type", "application/json")
92+
req = req.WithContext(aibridge.AsActor(req.Context(), userID, nil))
93+
rec := httptest.NewRecorder()
94+
bridge.ServeHTTP(rec, req)
95+
require.Equal(t, 200, rec.Code)
96+
97+
// Check that log files were created
98+
// Parse the request to get the model name
99+
var reqData map[string]any
100+
require.NoError(t, json.Unmarshal(files[fixtureRequest], &reqData))
101+
model := reqData["model"].(string)
102+
103+
logDir := filepath.Join(tmpDir, tc.provider, model)
104+
entries, err := os.ReadDir(logDir)
105+
require.NoError(t, err, "log directory should exist")
106+
require.NotEmpty(t, entries, "log directory should contain files")
107+
108+
// Should have at least one .req.log and one .res.log file
109+
var hasReq, hasRes bool
110+
for _, entry := range entries {
111+
name := entry.Name()
112+
if strings.HasSuffix(name, ".req.log") {
113+
hasReq = true
114+
// Verify the file has content
115+
content, err := os.ReadFile(filepath.Join(logDir, name))
116+
require.NoError(t, err)
117+
require.NotEmpty(t, content, "request log should have content")
118+
require.Contains(t, string(content), "POST")
119+
} else if strings.HasSuffix(name, ".res.log") {
120+
hasRes = true
121+
// Verify the file has content
122+
content, err := os.ReadFile(filepath.Join(logDir, name))
123+
require.NoError(t, err)
124+
require.NotEmpty(t, content, "response log should have content")
125+
require.Contains(t, string(content), "200")
126+
}
127+
}
128+
require.True(t, hasReq, "should have request log file")
129+
require.True(t, hasRes, "should have response log file")
130+
})
131+
}
132+
}
133+
134+
func TestSanitizeModelName(t *testing.T) {
135+
t.Parallel()
136+
137+
tests := []struct {
138+
name string
139+
input string
140+
expected string
141+
}{
142+
{
143+
name: "simple model",
144+
input: "gpt-4o",
145+
expected: "gpt-4o",
146+
},
147+
{
148+
name: "model with slash",
149+
input: "gpt-4o/mini",
150+
expected: "gpt-4o_mini",
151+
},
152+
{
153+
name: "model with colon",
154+
input: "o1:2024-12-17",
155+
expected: "o1_2024-12-17",
156+
},
157+
{
158+
name: "model with backslash",
159+
input: "model\\name",
160+
expected: "model_name",
161+
},
162+
{
163+
name: "model with multiple special chars",
164+
input: "model:name/version?",
165+
expected: "model_name_version_",
166+
},
167+
}
168+
169+
for _, tt := range tests {
170+
t.Run(tt.name, func(t *testing.T) {
171+
result := aibridge.SanitizeModelName(tt.input)
172+
require.Equal(t, tt.expected, result)
173+
})
174+
}
175+
}

0 commit comments

Comments
 (0)