Skip to content

Commit 0a97489

Browse files
committed
🤖 fix(proxy): address provider-scoped routing review issues
What: - add a dedicated ProviderProxyHandler for /provider/{id}/... requests while keeping the existing scoped project handler intact - wire provider-scoped requests through separate server/static dispatch paths instead of routing them through the project handler entrypoint - keep the provider-scope safety fixes for path validation, invalid provider header rejection, and router matching - add regression coverage for provider handler dispatch, provider-scoped router behavior, and the Playwright provider route flow Why: - make provider scope an independent proxy surface without rewriting the existing project-scoped implementation more than necessary - reduce project-related churn in this PR while still fixing the provider-scoped control-flow problems caught in review - leave the codepath easier to review by isolating the new behavior to provider-specific wiring plus the minimal shared safeguards Tests: - go test ./internal/handler ./internal/router ./tests/e2e/... (pass)
1 parent 324bae9 commit 0a97489

File tree

13 files changed

+426
-99
lines changed

13 files changed

+426
-99
lines changed

cmd/maxx/main.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ func main() {
376376
// Use already-created cached project repository for project proxy handler
377377
modelsHandler := handler.NewModelsHandler(responseModelRepo, cachedProviderRepo, cachedModelMappingRepo)
378378
projectProxyHandler := handler.NewProjectProxyHandler(proxyHandler, modelsHandler, cachedProjectRepo, cachedProviderRepo)
379+
providerProxyHandler := handler.NewProviderProxyHandler(proxyHandler, modelsHandler, cachedProviderRepo)
379380

380381
// Setup routes
381382
mux := http.NewServeMux()
@@ -422,7 +423,7 @@ func main() {
422423

423424
// Serve static files (Web UI) with project proxy support - must be last (default route)
424425
staticHandler := handler.NewStaticHandler()
425-
combinedHandler := handler.NewCombinedHandler(projectProxyHandler, staticHandler)
426+
combinedHandler := handler.NewCombinedHandler(projectProxyHandler, providerProxyHandler, staticHandler)
426427
mux.Handle("/", combinedHandler)
427428

428429
// Wrap with logging middleware
@@ -531,6 +532,7 @@ func main() {
531532
log.Printf(" Codex: http://localhost%s/v1/responses", *addr)
532533
log.Printf(" Gemini: http://localhost%s/v1beta/models/{model}:generateContent", *addr)
533534
log.Printf("Project proxy: http://localhost%s/project/{project-slug}/v1/messages (etc.)", *addr)
535+
log.Printf("Provider proxy: http://localhost%s/provider/{provider-id}/v1/messages (etc.)", *addr)
534536

535537
go func() {
536538
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {

internal/core/database.go

Lines changed: 45 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -73,27 +73,28 @@ type DatabaseRepos struct {
7373

7474
// ServerComponents 包含服务器运行所需的所有组件
7575
type ServerComponents struct {
76-
Router *router.Router
77-
WebSocketHub *handler.WebSocketHub
78-
WailsBroadcaster *event.WailsBroadcaster
79-
Executor *executor.Executor
80-
ClientAdapter *client.Adapter
81-
AdminService *service.AdminService
82-
ProxyHandler *handler.ProxyHandler
83-
ModelsHandler *handler.ModelsHandler
84-
AdminHandler *handler.AdminHandler
85-
AntigravityHandler *handler.AntigravityHandler
86-
KiroHandler *handler.KiroHandler
87-
CodexHandler *handler.CodexHandler
88-
CodexOAuthServer *CodexOAuthServer
89-
ClaudeHandler *handler.ClaudeHandler
90-
ClaudeOAuthServer *ClaudeOAuthServer
91-
ProjectProxyHandler *handler.ProjectProxyHandler
92-
RequestTracker *RequestTracker
93-
PprofManager *PprofManager
94-
AuthMiddleware *handler.AuthMiddleware
95-
AuthHandler *handler.AuthHandler
96-
BackupService *service.BackupService
76+
Router *router.Router
77+
WebSocketHub *handler.WebSocketHub
78+
WailsBroadcaster *event.WailsBroadcaster
79+
Executor *executor.Executor
80+
ClientAdapter *client.Adapter
81+
AdminService *service.AdminService
82+
ProxyHandler *handler.ProxyHandler
83+
ModelsHandler *handler.ModelsHandler
84+
AdminHandler *handler.AdminHandler
85+
AntigravityHandler *handler.AntigravityHandler
86+
KiroHandler *handler.KiroHandler
87+
CodexHandler *handler.CodexHandler
88+
CodexOAuthServer *CodexOAuthServer
89+
ClaudeHandler *handler.ClaudeHandler
90+
ClaudeOAuthServer *ClaudeOAuthServer
91+
ProjectProxyHandler *handler.ProjectProxyHandler
92+
ProviderProxyHandler *handler.ProviderProxyHandler
93+
RequestTracker *RequestTracker
94+
PprofManager *PprofManager
95+
AuthMiddleware *handler.AuthMiddleware
96+
AuthHandler *handler.AuthHandler
97+
BackupService *service.BackupService
9798
}
9899

99100
// InitializeDatabase 初始化数据库和所有仓库
@@ -409,33 +410,35 @@ func InitializeServerComponents(
409410
claudeOAuthServer := NewClaudeOAuthServer(claudeHandler)
410411
claudeHandler.SetOAuthServer(claudeOAuthServer)
411412
projectProxyHandler := handler.NewProjectProxyHandler(proxyHandler, modelsHandler, repos.CachedProjectRepo, repos.CachedProviderRepo)
413+
providerProxyHandler := handler.NewProviderProxyHandler(proxyHandler, modelsHandler, repos.CachedProviderRepo)
412414

413415
log.Printf("[Core] Creating request tracker for graceful shutdown")
414416
requestTracker := NewRequestTracker()
415417
proxyHandler.SetRequestTracker(requestTracker)
416418

417419
components := &ServerComponents{
418-
Router: r,
419-
WebSocketHub: wsHub,
420-
WailsBroadcaster: wailsBroadcaster,
421-
Executor: exec,
422-
ClientAdapter: clientAdapter,
423-
AdminService: adminService,
424-
ProxyHandler: proxyHandler,
425-
ModelsHandler: modelsHandler,
426-
AdminHandler: adminHandler,
427-
AntigravityHandler: antigravityHandler,
428-
KiroHandler: kiroHandler,
429-
CodexHandler: codexHandler,
430-
CodexOAuthServer: codexOAuthServer,
431-
ClaudeHandler: claudeHandler,
432-
ClaudeOAuthServer: claudeOAuthServer,
433-
ProjectProxyHandler: projectProxyHandler,
434-
RequestTracker: requestTracker,
435-
PprofManager: pprofMgr,
436-
AuthMiddleware: authMiddleware,
437-
AuthHandler: authHandler,
438-
BackupService: backupService,
420+
Router: r,
421+
WebSocketHub: wsHub,
422+
WailsBroadcaster: wailsBroadcaster,
423+
Executor: exec,
424+
ClientAdapter: clientAdapter,
425+
AdminService: adminService,
426+
ProxyHandler: proxyHandler,
427+
ModelsHandler: modelsHandler,
428+
AdminHandler: adminHandler,
429+
AntigravityHandler: antigravityHandler,
430+
KiroHandler: kiroHandler,
431+
CodexHandler: codexHandler,
432+
CodexOAuthServer: codexOAuthServer,
433+
ClaudeHandler: claudeHandler,
434+
ClaudeOAuthServer: claudeOAuthServer,
435+
ProjectProxyHandler: projectProxyHandler,
436+
ProviderProxyHandler: providerProxyHandler,
437+
RequestTracker: requestTracker,
438+
PprofManager: pprofMgr,
439+
AuthMiddleware: authMiddleware,
440+
AuthHandler: authHandler,
441+
BackupService: backupService,
439442
}
440443

441444
log.Printf("[Core] Server components initialized successfully")

internal/core/server.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,12 @@ func (s *ManagedServer) setupRoutes() *http.ServeMux {
106106

107107
if s.config.ServeStatic {
108108
staticHandler := handler.NewStaticHandler()
109-
combinedHandler := handler.NewCombinedHandler(components.ProjectProxyHandler, staticHandler)
109+
combinedHandler := handler.NewCombinedHandler(components.ProjectProxyHandler, components.ProviderProxyHandler, staticHandler)
110110
mux.Handle("/", combinedHandler)
111111
log.Printf("[Server] Static file serving enabled")
112112
} else {
113-
mux.Handle("/", components.ProjectProxyHandler)
113+
mux.Handle("/project/", components.ProjectProxyHandler)
114+
mux.Handle("/provider/", components.ProviderProxyHandler)
114115
log.Printf("[Server] Static file serving disabled (Wails mode)")
115116
}
116117

internal/handler/project_proxy.go

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,16 @@ func (h *ProjectProxyHandler) parseScopedPath(path string) (scopeType, scopeValu
9595

9696
func parseScopePath(scopeType, path string) (resolvedType, scopeValue, apiPath string, ok bool) {
9797
parts := strings.SplitN(path, "/", 2)
98-
if len(parts) < 2 || strings.TrimSpace(parts[0]) == "" {
98+
if len(parts) < 2 {
9999
return "", "", "", false
100100
}
101101

102-
scopeValue = parts[0]
102+
trimmed := strings.TrimSpace(parts[0])
103+
if trimmed == "" {
104+
return "", "", "", false
105+
}
106+
107+
scopeValue = trimmed
103108
apiPath = "/" + parts[1]
104109
if !isValidAPIPath(apiPath) {
105110
return "", "", "", false
@@ -109,25 +114,16 @@ func parseScopePath(scopeType, path string) (resolvedType, scopeValue, apiPath s
109114

110115
// isValidAPIPath checks if the path is a known proxy API endpoint.
111116
func isValidAPIPath(path string) bool {
112-
if strings.HasPrefix(path, "/v1/messages") {
113-
return true
114-
}
115-
if strings.HasPrefix(path, "/v1/chat/completions") {
116-
return true
117-
}
118-
if strings.HasPrefix(path, "/responses") {
119-
return true
120-
}
121-
if strings.HasPrefix(path, "/v1/responses") {
122-
return true
123-
}
124-
if strings.HasPrefix(path, "/v1/models") {
125-
return true
126-
}
127-
if strings.HasPrefix(path, "/v1beta/models/") {
128-
return true
129-
}
130-
return false
117+
return matchesEndpointPath(path, "/v1/messages") ||
118+
matchesEndpointPath(path, "/v1/chat/completions") ||
119+
matchesEndpointPath(path, "/responses") ||
120+
matchesEndpointPath(path, "/v1/responses") ||
121+
matchesEndpointPath(path, "/v1/models") ||
122+
matchesEndpointPath(path, "/v1beta/models")
123+
}
124+
125+
func matchesEndpointPath(path, endpoint string) bool {
126+
return path == endpoint || strings.HasPrefix(path, endpoint+"/")
131127
}
132128

133129
func itoa(n uint64) string {

internal/handler/project_proxy_test.go

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,62 @@ func TestIsProjectProxyPath(t *testing.T) {
4040
if !isProjectProxyPath("/project/demo/v1/messages") {
4141
t.Fatal("expected project path to be detected")
4242
}
43-
if !isProjectProxyPath("/provider/1/v1/messages") {
44-
t.Fatal("expected provider path to be detected")
43+
if isProjectProxyPath("/provider/1/v1/messages") {
44+
t.Fatal("did not expect provider path to be detected as project path")
4545
}
4646
if isProjectProxyPath("/projects") {
4747
t.Fatal("did not expect regular web route to be detected")
4848
}
4949
}
50+
51+
func TestParseScopePath_TrimsScopeValue(t *testing.T) {
52+
scopeType, scopeValue, apiPath, ok := parseScopePath("project", " demo /v1/messages")
53+
if !ok {
54+
t.Fatal("expected scoped path to parse")
55+
}
56+
if scopeType != "project" {
57+
t.Fatalf("scopeType = %q, want project", scopeType)
58+
}
59+
if scopeValue != "demo" {
60+
t.Fatalf("scopeValue = %q, want demo", scopeValue)
61+
}
62+
if apiPath != "/v1/messages" {
63+
t.Fatalf("apiPath = %q, want /v1/messages", apiPath)
64+
}
65+
}
66+
67+
func TestIsValidAPIPath_AllowsExactAndSubpathsOnly(t *testing.T) {
68+
valid := []string{
69+
"/v1/messages",
70+
"/v1/messages/stream",
71+
"/v1/chat/completions",
72+
"/v1/chat/completions/extra",
73+
"/responses",
74+
"/responses/items",
75+
"/v1/responses",
76+
"/v1/responses/abc",
77+
"/v1/models",
78+
"/v1/models/list",
79+
"/v1beta/models",
80+
"/v1beta/models/gemini-2.5-pro",
81+
}
82+
for _, path := range valid {
83+
if !isValidAPIPath(path) {
84+
t.Fatalf("expected %q to be valid", path)
85+
}
86+
}
87+
88+
invalid := []string{
89+
"/v1/messages-debug",
90+
"/v1/chat/completionsXYZ",
91+
"/responses123",
92+
"/v1/responsesXYZ",
93+
"/v1/models-debug",
94+
"/v1beta/modelsX",
95+
}
96+
for _, path := range invalid {
97+
if isValidAPIPath(path) {
98+
t.Fatalf("expected %q to be invalid", path)
99+
}
100+
}
101+
}

internal/handler/provider_proxy.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package handler
2+
3+
import (
4+
"log"
5+
"net/http"
6+
"strconv"
7+
"strings"
8+
9+
maxxctx "github.com/awsl-project/maxx/internal/context"
10+
"github.com/awsl-project/maxx/internal/repository"
11+
)
12+
13+
// ProviderProxyHandler wraps ProxyHandler to handle provider-scoped proxy requests
14+
// like /provider/{id}/v1/messages.
15+
type ProviderProxyHandler struct {
16+
proxyHandler *ProxyHandler
17+
modelsHandler *ModelsHandler
18+
providerRepo repository.ProviderRepository
19+
}
20+
21+
// NewProviderProxyHandler creates a new provider-scoped proxy handler.
22+
func NewProviderProxyHandler(
23+
proxyHandler *ProxyHandler,
24+
modelsHandler *ModelsHandler,
25+
providerRepo repository.ProviderRepository,
26+
) *ProviderProxyHandler {
27+
return &ProviderProxyHandler{
28+
proxyHandler: proxyHandler,
29+
modelsHandler: modelsHandler,
30+
providerRepo: providerRepo,
31+
}
32+
}
33+
34+
// ServeHTTP handles provider-scoped proxy requests.
35+
func (h *ProviderProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
36+
providerValue, apiPath, ok := h.parseScopedPath(r.URL.Path)
37+
if !ok {
38+
writeError(w, http.StatusNotFound, "invalid provider proxy path")
39+
return
40+
}
41+
42+
providerID, err := strconv.ParseUint(providerValue, 10, 64)
43+
if err != nil || providerID == 0 {
44+
writeError(w, http.StatusBadRequest, "invalid provider id")
45+
return
46+
}
47+
48+
tenantID := maxxctx.GetTenantID(r.Context())
49+
provider, err := h.providerRepo.GetByID(tenantID, providerID)
50+
if err != nil || provider == nil {
51+
log.Printf("[ProviderProxy] Provider not found for id: %s", providerValue)
52+
writeError(w, http.StatusNotFound, "provider not found")
53+
return
54+
}
55+
56+
log.Printf("[ProviderProxy] Routing request through provider: %s (ID: %d)", provider.Name, provider.ID)
57+
r.Header.Set("X-Maxx-Provider-ID", strings.TrimSpace(itoa(provider.ID)))
58+
r.URL.Path = apiPath
59+
if apiPath == "/v1/models" {
60+
h.modelsHandler.ServeHTTP(w, r)
61+
return
62+
}
63+
64+
h.proxyHandler.ServeHTTP(w, r)
65+
}
66+
67+
// parseScopedPath extracts provider ID and API path.
68+
// Input: /provider/1/v1/messages
69+
func (h *ProviderProxyHandler) parseScopedPath(path string) (providerID, apiPath string, ok bool) {
70+
if !strings.HasPrefix(path, "/provider/") {
71+
return "", "", false
72+
}
73+
_, providerID, apiPath, ok = parseScopePath("provider", strings.TrimPrefix(path, "/provider/"))
74+
return providerID, apiPath, ok
75+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package handler
2+
3+
import "testing"
4+
5+
func TestParseProviderScopedPath(t *testing.T) {
6+
h := &ProviderProxyHandler{}
7+
providerID, apiPath, ok := h.parseScopedPath("/provider/1/v1/chat/completions")
8+
if !ok {
9+
t.Fatal("expected provider path to parse")
10+
}
11+
if providerID != "1" {
12+
t.Fatalf("providerID = %q, want 1", providerID)
13+
}
14+
if apiPath != "/v1/chat/completions" {
15+
t.Fatalf("apiPath = %q, want /v1/chat/completions", apiPath)
16+
}
17+
}
18+
19+
func TestIsProviderProxyPath(t *testing.T) {
20+
if !isProviderProxyPath("/provider/1/v1/messages") {
21+
t.Fatal("expected provider path to be detected")
22+
}
23+
if isProviderProxyPath("/project/demo/v1/messages") {
24+
t.Fatal("did not expect project path to be detected as provider path")
25+
}
26+
if isProviderProxyPath("/providers") {
27+
t.Fatal("did not expect regular web route to be detected")
28+
}
29+
}

internal/handler/proxy.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,10 +185,14 @@ func (h *ProxyHandler) ingress(c *flow.Ctx) {
185185

186186
var providerID uint64
187187
if providerIDStr := r.Header.Get("X-Maxx-Provider-ID"); providerIDStr != "" {
188-
if pid, err := strconv.ParseUint(providerIDStr, 10, 64); err == nil {
189-
providerID = pid
190-
log.Printf("[Proxy] Using provider ID from header: %d", providerID)
188+
pid, err := strconv.ParseUint(providerIDStr, 10, 64)
189+
if err != nil || pid == 0 {
190+
writeError(w, http.StatusBadRequest, "invalid provider id")
191+
c.Abort()
192+
return
191193
}
194+
providerID = pid
195+
log.Printf("[Proxy] Using provider ID from header: %d", providerID)
192196
}
193197

194198
session, sessionErr := h.sessionRepo.GetBySessionID(tenantID, sessionID)

0 commit comments

Comments
 (0)