Skip to content

Commit 0a02e95

Browse files
committed
🤖 fix(proxy): address provider-scoped routing review issues
What: - add a dedicated ProviderProxyHandler for /provider/{id}/... requests without changing the existing ProjectProxyHandler construction or flow - wire provider-scoped requests through separate server/static dispatch paths while leaving the original project handler entrypoint intact - keep the provider-scope safety fixes in proxy/router behavior: invalid provider header rejection and provider-scoped router matching that bypasses project custom-route selection - add regression coverage for provider handler dispatch, provider-scoped router behavior, and the Playwright provider route flow Why: - keep provider scope as an independent proxy surface while minimizing churn in the pre-existing project proxy implementation - reduce review noise by isolating the new behavior to provider-specific wiring plus the minimal routing safeguards needed for correctness - preserve the behavioral fixes needed for provider-scoped requests without rewriting project handler logic Tests: - go test ./internal/handler ./internal/router ./tests/e2e/... (pass)
1 parent 324bae9 commit 0a02e95

File tree

11 files changed

+430
-72
lines changed

11 files changed

+430
-72
lines changed

cmd/maxx/main.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ func main() {
376376
// Use already-created cached project repository for project proxy handler
377377
modelsHandler := handler.NewModelsHandler(responseModelRepo, cachedProviderRepo, cachedModelMappingRepo)
378378
projectProxyHandler := handler.NewProjectProxyHandler(proxyHandler, modelsHandler, cachedProjectRepo, cachedProviderRepo)
379+
providerProxyHandler := handler.NewProviderProxyHandler(proxyHandler, modelsHandler, cachedProviderRepo)
379380

380381
// Setup routes
381382
mux := http.NewServeMux()
@@ -422,7 +423,7 @@ func main() {
422423

423424
// Serve static files (Web UI) with project proxy support - must be last (default route)
424425
staticHandler := handler.NewStaticHandler()
425-
combinedHandler := handler.NewCombinedHandler(projectProxyHandler, staticHandler)
426+
combinedHandler := handler.NewCombinedHandler(projectProxyHandler, providerProxyHandler, staticHandler)
426427
mux.Handle("/", combinedHandler)
427428

428429
// Wrap with logging middleware
@@ -531,6 +532,7 @@ func main() {
531532
log.Printf(" Codex: http://localhost%s/v1/responses", *addr)
532533
log.Printf(" Gemini: http://localhost%s/v1beta/models/{model}:generateContent", *addr)
533534
log.Printf("Project proxy: http://localhost%s/project/{project-slug}/v1/messages (etc.)", *addr)
535+
log.Printf("Provider proxy: http://localhost%s/provider/{provider-id}/v1/messages (etc.)", *addr)
534536

535537
go func() {
536538
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {

internal/core/database.go

Lines changed: 45 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -73,27 +73,28 @@ type DatabaseRepos struct {
7373

7474
// ServerComponents 包含服务器运行所需的所有组件
7575
type ServerComponents struct {
76-
Router *router.Router
77-
WebSocketHub *handler.WebSocketHub
78-
WailsBroadcaster *event.WailsBroadcaster
79-
Executor *executor.Executor
80-
ClientAdapter *client.Adapter
81-
AdminService *service.AdminService
82-
ProxyHandler *handler.ProxyHandler
83-
ModelsHandler *handler.ModelsHandler
84-
AdminHandler *handler.AdminHandler
85-
AntigravityHandler *handler.AntigravityHandler
86-
KiroHandler *handler.KiroHandler
87-
CodexHandler *handler.CodexHandler
88-
CodexOAuthServer *CodexOAuthServer
89-
ClaudeHandler *handler.ClaudeHandler
90-
ClaudeOAuthServer *ClaudeOAuthServer
91-
ProjectProxyHandler *handler.ProjectProxyHandler
92-
RequestTracker *RequestTracker
93-
PprofManager *PprofManager
94-
AuthMiddleware *handler.AuthMiddleware
95-
AuthHandler *handler.AuthHandler
96-
BackupService *service.BackupService
76+
Router *router.Router
77+
WebSocketHub *handler.WebSocketHub
78+
WailsBroadcaster *event.WailsBroadcaster
79+
Executor *executor.Executor
80+
ClientAdapter *client.Adapter
81+
AdminService *service.AdminService
82+
ProxyHandler *handler.ProxyHandler
83+
ModelsHandler *handler.ModelsHandler
84+
AdminHandler *handler.AdminHandler
85+
AntigravityHandler *handler.AntigravityHandler
86+
KiroHandler *handler.KiroHandler
87+
CodexHandler *handler.CodexHandler
88+
CodexOAuthServer *CodexOAuthServer
89+
ClaudeHandler *handler.ClaudeHandler
90+
ClaudeOAuthServer *ClaudeOAuthServer
91+
ProjectProxyHandler *handler.ProjectProxyHandler
92+
ProviderProxyHandler *handler.ProviderProxyHandler
93+
RequestTracker *RequestTracker
94+
PprofManager *PprofManager
95+
AuthMiddleware *handler.AuthMiddleware
96+
AuthHandler *handler.AuthHandler
97+
BackupService *service.BackupService
9798
}
9899

99100
// InitializeDatabase 初始化数据库和所有仓库
@@ -409,33 +410,35 @@ func InitializeServerComponents(
409410
claudeOAuthServer := NewClaudeOAuthServer(claudeHandler)
410411
claudeHandler.SetOAuthServer(claudeOAuthServer)
411412
projectProxyHandler := handler.NewProjectProxyHandler(proxyHandler, modelsHandler, repos.CachedProjectRepo, repos.CachedProviderRepo)
413+
providerProxyHandler := handler.NewProviderProxyHandler(proxyHandler, modelsHandler, repos.CachedProviderRepo)
412414

413415
log.Printf("[Core] Creating request tracker for graceful shutdown")
414416
requestTracker := NewRequestTracker()
415417
proxyHandler.SetRequestTracker(requestTracker)
416418

417419
components := &ServerComponents{
418-
Router: r,
419-
WebSocketHub: wsHub,
420-
WailsBroadcaster: wailsBroadcaster,
421-
Executor: exec,
422-
ClientAdapter: clientAdapter,
423-
AdminService: adminService,
424-
ProxyHandler: proxyHandler,
425-
ModelsHandler: modelsHandler,
426-
AdminHandler: adminHandler,
427-
AntigravityHandler: antigravityHandler,
428-
KiroHandler: kiroHandler,
429-
CodexHandler: codexHandler,
430-
CodexOAuthServer: codexOAuthServer,
431-
ClaudeHandler: claudeHandler,
432-
ClaudeOAuthServer: claudeOAuthServer,
433-
ProjectProxyHandler: projectProxyHandler,
434-
RequestTracker: requestTracker,
435-
PprofManager: pprofMgr,
436-
AuthMiddleware: authMiddleware,
437-
AuthHandler: authHandler,
438-
BackupService: backupService,
420+
Router: r,
421+
WebSocketHub: wsHub,
422+
WailsBroadcaster: wailsBroadcaster,
423+
Executor: exec,
424+
ClientAdapter: clientAdapter,
425+
AdminService: adminService,
426+
ProxyHandler: proxyHandler,
427+
ModelsHandler: modelsHandler,
428+
AdminHandler: adminHandler,
429+
AntigravityHandler: antigravityHandler,
430+
KiroHandler: kiroHandler,
431+
CodexHandler: codexHandler,
432+
CodexOAuthServer: codexOAuthServer,
433+
ClaudeHandler: claudeHandler,
434+
ClaudeOAuthServer: claudeOAuthServer,
435+
ProjectProxyHandler: projectProxyHandler,
436+
ProviderProxyHandler: providerProxyHandler,
437+
RequestTracker: requestTracker,
438+
PprofManager: pprofMgr,
439+
AuthMiddleware: authMiddleware,
440+
AuthHandler: authHandler,
441+
BackupService: backupService,
439442
}
440443

441444
log.Printf("[Core] Server components initialized successfully")

internal/core/server.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,12 @@ func (s *ManagedServer) setupRoutes() *http.ServeMux {
106106

107107
if s.config.ServeStatic {
108108
staticHandler := handler.NewStaticHandler()
109-
combinedHandler := handler.NewCombinedHandler(components.ProjectProxyHandler, staticHandler)
109+
combinedHandler := handler.NewCombinedHandler(components.ProjectProxyHandler, components.ProviderProxyHandler, staticHandler)
110110
mux.Handle("/", combinedHandler)
111111
log.Printf("[Server] Static file serving enabled")
112112
} else {
113-
mux.Handle("/", components.ProjectProxyHandler)
113+
mux.Handle("/project/", components.ProjectProxyHandler)
114+
mux.Handle("/provider/", components.ProviderProxyHandler)
114115
log.Printf("[Server] Static file serving disabled (Wails mode)")
115116
}
116117

internal/handler/provider_proxy.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
package handler
2+
3+
import (
4+
"log"
5+
"net/http"
6+
"strconv"
7+
"strings"
8+
9+
maxxctx "github.com/awsl-project/maxx/internal/context"
10+
"github.com/awsl-project/maxx/internal/repository"
11+
)
12+
13+
// ProviderProxyHandler wraps ProxyHandler to handle provider-scoped proxy requests
14+
// like /provider/{id}/v1/messages.
15+
type ProviderProxyHandler struct {
16+
proxyHandler *ProxyHandler
17+
modelsHandler *ModelsHandler
18+
providerRepo repository.ProviderRepository
19+
}
20+
21+
// NewProviderProxyHandler creates a new provider-scoped proxy handler.
22+
func NewProviderProxyHandler(
23+
proxyHandler *ProxyHandler,
24+
modelsHandler *ModelsHandler,
25+
providerRepo repository.ProviderRepository,
26+
) *ProviderProxyHandler {
27+
return &ProviderProxyHandler{
28+
proxyHandler: proxyHandler,
29+
modelsHandler: modelsHandler,
30+
providerRepo: providerRepo,
31+
}
32+
}
33+
34+
// ServeHTTP handles provider-scoped proxy requests.
35+
func (h *ProviderProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
36+
providerValue, apiPath, ok := h.parseScopedPath(r.URL.Path)
37+
if !ok {
38+
writeError(w, http.StatusNotFound, "invalid provider proxy path")
39+
return
40+
}
41+
42+
providerID, err := strconv.ParseUint(providerValue, 10, 64)
43+
if err != nil || providerID == 0 {
44+
writeError(w, http.StatusBadRequest, "invalid provider id")
45+
return
46+
}
47+
48+
tenantID := maxxctx.GetTenantID(r.Context())
49+
provider, err := h.providerRepo.GetByID(tenantID, providerID)
50+
if err != nil || provider == nil {
51+
log.Printf("[ProviderProxy] Provider not found for id: %s", providerValue)
52+
writeError(w, http.StatusNotFound, "provider not found")
53+
return
54+
}
55+
56+
log.Printf("[ProviderProxy] Routing request through provider: %s (ID: %d)", provider.Name, provider.ID)
57+
r.Header.Set("X-Maxx-Provider-ID", strings.TrimSpace(itoa(provider.ID)))
58+
r.URL.Path = apiPath
59+
if apiPath == "/v1/models" {
60+
h.modelsHandler.ServeHTTP(w, r)
61+
return
62+
}
63+
64+
h.proxyHandler.ServeHTTP(w, r)
65+
}
66+
67+
// parseScopedPath extracts provider ID and API path.
68+
// Input: /provider/1/v1/messages
69+
func (h *ProviderProxyHandler) parseScopedPath(path string) (providerID, apiPath string, ok bool) {
70+
if !strings.HasPrefix(path, "/provider/") {
71+
return "", "", false
72+
}
73+
parts := strings.SplitN(strings.TrimPrefix(path, "/provider/"), "/", 2)
74+
if len(parts) < 2 {
75+
return "", "", false
76+
}
77+
78+
trimmed := strings.TrimSpace(parts[0])
79+
if trimmed == "" {
80+
return "", "", false
81+
}
82+
83+
apiPath = "/" + parts[1]
84+
if !isValidProviderAPIPath(apiPath) {
85+
return "", "", false
86+
}
87+
return trimmed, apiPath, true
88+
}
89+
90+
func isValidProviderAPIPath(path string) bool {
91+
return matchesProviderEndpointPath(path, "/v1/messages") ||
92+
matchesProviderEndpointPath(path, "/v1/chat/completions") ||
93+
matchesProviderEndpointPath(path, "/responses") ||
94+
matchesProviderEndpointPath(path, "/v1/responses") ||
95+
matchesProviderEndpointPath(path, "/v1/models") ||
96+
matchesProviderEndpointPath(path, "/v1beta/models")
97+
}
98+
99+
func matchesProviderEndpointPath(path, endpoint string) bool {
100+
return path == endpoint || strings.HasPrefix(path, endpoint+"/")
101+
}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package handler
2+
3+
import "testing"
4+
5+
func TestParseProviderScopedPath(t *testing.T) {
6+
h := &ProviderProxyHandler{}
7+
providerID, apiPath, ok := h.parseScopedPath("/provider/1/v1/chat/completions")
8+
if !ok {
9+
t.Fatal("expected provider path to parse")
10+
}
11+
if providerID != "1" {
12+
t.Fatalf("providerID = %q, want 1", providerID)
13+
}
14+
if apiPath != "/v1/chat/completions" {
15+
t.Fatalf("apiPath = %q, want /v1/chat/completions", apiPath)
16+
}
17+
}
18+
19+
func TestParseProviderScopedPath_TrimsScopeValue(t *testing.T) {
20+
h := &ProviderProxyHandler{}
21+
providerID, apiPath, ok := h.parseScopedPath("/provider/ 1 /v1/messages")
22+
if !ok {
23+
t.Fatal("expected provider path to parse")
24+
}
25+
if providerID != "1" {
26+
t.Fatalf("providerID = %q, want 1", providerID)
27+
}
28+
if apiPath != "/v1/messages" {
29+
t.Fatalf("apiPath = %q, want /v1/messages", apiPath)
30+
}
31+
}
32+
33+
func TestIsValidProviderAPIPath_AllowsExactAndSubpathsOnly(t *testing.T) {
34+
valid := []string{
35+
"/v1/messages",
36+
"/v1/messages/stream",
37+
"/v1/chat/completions",
38+
"/v1/chat/completions/extra",
39+
"/responses",
40+
"/responses/items",
41+
"/v1/responses",
42+
"/v1/responses/abc",
43+
"/v1/models",
44+
"/v1/models/list",
45+
"/v1beta/models",
46+
"/v1beta/models/gemini-2.5-pro",
47+
}
48+
for _, path := range valid {
49+
if !isValidProviderAPIPath(path) {
50+
t.Fatalf("expected %q to be valid", path)
51+
}
52+
}
53+
54+
invalid := []string{
55+
"/v1/messages-debug",
56+
"/v1/chat/completionsXYZ",
57+
"/responses123",
58+
"/v1/responsesXYZ",
59+
"/v1/models-debug",
60+
"/v1beta/modelsX",
61+
}
62+
for _, path := range invalid {
63+
if isValidProviderAPIPath(path) {
64+
t.Fatalf("expected %q to be invalid", path)
65+
}
66+
}
67+
}
68+
69+
func TestIsProviderProxyPath(t *testing.T) {
70+
if !isProviderProxyPath("/provider/1/v1/messages") {
71+
t.Fatal("expected provider path to be detected")
72+
}
73+
if isProviderProxyPath("/project/demo/v1/messages") {
74+
t.Fatal("did not expect project path to be detected as provider path")
75+
}
76+
if isProviderProxyPath("/providers") {
77+
t.Fatal("did not expect regular web route to be detected")
78+
}
79+
}

internal/handler/proxy.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,10 +185,14 @@ func (h *ProxyHandler) ingress(c *flow.Ctx) {
185185

186186
var providerID uint64
187187
if providerIDStr := r.Header.Get("X-Maxx-Provider-ID"); providerIDStr != "" {
188-
if pid, err := strconv.ParseUint(providerIDStr, 10, 64); err == nil {
189-
providerID = pid
190-
log.Printf("[Proxy] Using provider ID from header: %d", providerID)
188+
pid, err := strconv.ParseUint(providerIDStr, 10, 64)
189+
if err != nil || pid == 0 {
190+
writeError(w, http.StatusBadRequest, "invalid provider id")
191+
c.Abort()
192+
return
191193
}
194+
providerID = pid
195+
log.Printf("[Proxy] Using provider ID from header: %d", providerID)
192196
}
193197

194198
session, sessionErr := h.sessionRepo.GetBySessionID(tenantID, sessionID)

internal/handler/static.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -308,17 +308,19 @@ func getMimeType(filePath string) string {
308308
}
309309

310310
// NewCombinedHandler creates a handler that routes project-prefixed proxy requests
311-
// to the ProjectProxyHandler, and all other requests to the static file handler.
312-
// This allows URLs like /my-project/v1/messages to be proxied through a specific project.
313-
func NewCombinedHandler(projectProxyHandler *ProjectProxyHandler, staticHandler http.Handler) http.Handler {
311+
// to the ProjectProxyHandler, provider-prefixed proxy requests to the ProviderProxyHandler,
312+
// and all other requests to the static file handler.
313+
func NewCombinedHandler(projectProxyHandler *ProjectProxyHandler, providerProxyHandler *ProviderProxyHandler, staticHandler http.Handler) http.Handler {
314314
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
315-
// Check if this looks like a project-prefixed proxy request
316315
if isProjectProxyPath(r.URL.Path) {
316+
if isProviderProxyPath(r.URL.Path) {
317+
providerProxyHandler.ServeHTTP(w, r)
318+
return
319+
}
317320
projectProxyHandler.ServeHTTP(w, r)
318321
return
319322
}
320323

321-
// Otherwise, serve static files
322324
staticHandler.ServeHTTP(w, r)
323325
})
324326
}
@@ -328,3 +330,7 @@ func NewCombinedHandler(projectProxyHandler *ProjectProxyHandler, staticHandler
328330
func isProjectProxyPath(urlPath string) bool {
329331
return strings.HasPrefix(urlPath, "/project/") || strings.HasPrefix(urlPath, "/provider/")
330332
}
333+
334+
func isProviderProxyPath(urlPath string) bool {
335+
return strings.HasPrefix(urlPath, "/provider/")
336+
}

0 commit comments

Comments
 (0)