Skip to content

Commit e63c47c

Browse files
authored
🤖 fix(proxy): add provider-scoped proxy handler (#446)
What: - rewrite provider-scoped proxy requests to use a dedicated direct-forward path instead of the generic router/executor route selection flow - wire provider proxy construction to the route and proxy-request repositories needed for one-to-one forwarding and request recording - drop the previous provider-scoped context/router/executor changes so the diff only keeps the required handler and test wiring Why: - provider-scoped requests should forward to the specified provider one-to-one without retry or generic route reuse - keeping the logic parallel to the existing project proxy entrypoint minimizes unrelated surface area in this PR Tests: - go test ./internal/handler ./internal/router ./tests/e2e/... (pass) - go build -o maxx ./cmd/maxx (pass) - MAXX_ADMIN_PASSWORD=test123 ./maxx -addr :9880 + MAXX_E2E_BASE_URL=http://127.0.0.1:9880 MAXX_E2E_USERNAME=admin MAXX_E2E_PASSWORD=test123 pnpm --dir tests/e2e/playwright test:provider-proxy-route (pass)
1 parent dfe1c17 commit e63c47c

File tree

9 files changed

+700
-0
lines changed

9 files changed

+700
-0
lines changed

‎cmd/maxx/main.go‎

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ func main() {
376376
// Use already-created cached project repository for project proxy handler
377377
modelsHandler := handler.NewModelsHandler(responseModelRepo, cachedProviderRepo, cachedModelMappingRepo)
378378
projectProxyHandler := handler.NewProjectProxyHandler(proxyHandler, modelsHandler, cachedProjectRepo)
379+
providerProxyHandler := handler.NewProviderProxyHandler(proxyHandler, modelsHandler, cachedProviderRepo, cachedRouteRepo, proxyRequestRepo)
379380

380381
// Setup routes
381382
mux := http.NewServeMux()
@@ -409,6 +410,8 @@ func main() {
409410
mux.Handle("/v1/responses/", proxyHandler)
410411
// Gemini API (Google AI Studio style)
411412
mux.Handle("/v1beta/models/", proxyHandler)
413+
// Provider-scoped proxy routes
414+
mux.Handle("/provider/", providerProxyHandler)
412415

413416
// Health check
414417
mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {

‎internal/core/database.go‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ type ServerComponents struct {
8989
ClaudeHandler *handler.ClaudeHandler
9090
ClaudeOAuthServer *ClaudeOAuthServer
9191
ProjectProxyHandler *handler.ProjectProxyHandler
92+
ProviderProxyHandler *handler.ProviderProxyHandler
9293
RequestTracker *RequestTracker
9394
PprofManager *PprofManager
9495
AuthMiddleware *handler.AuthMiddleware
@@ -431,6 +432,7 @@ func InitializeServerComponents(
431432
ClaudeHandler: claudeHandler,
432433
ClaudeOAuthServer: claudeOAuthServer,
433434
ProjectProxyHandler: projectProxyHandler,
435+
ProviderProxyHandler: handler.NewProviderProxyHandler(proxyHandler, modelsHandler, repos.CachedProviderRepo, repos.CachedRouteRepo, repos.ProxyRequestRepo),
434436
RequestTracker: requestTracker,
435437
PprofManager: pprofMgr,
436438
AuthMiddleware: authMiddleware,

‎internal/core/server.go‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ func (s *ManagedServer) setupRoutes() *http.ServeMux {
103103
})
104104

105105
mux.HandleFunc("/ws", components.WebSocketHub.HandleWebSocket)
106+
mux.Handle("/provider/", components.ProviderProxyHandler)
106107

107108
if s.config.ServeStatic {
108109
staticHandler := handler.NewStaticHandler()
Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
package handler
2+
3+
import (
4+
"log"
5+
"net/http"
6+
"strconv"
7+
"strings"
8+
"time"
9+
10+
provideradapter "github.com/awsl-project/maxx/internal/adapter/provider"
11+
maxxctx "github.com/awsl-project/maxx/internal/context"
12+
"github.com/awsl-project/maxx/internal/domain"
13+
"github.com/awsl-project/maxx/internal/executor"
14+
"github.com/awsl-project/maxx/internal/flow"
15+
"github.com/awsl-project/maxx/internal/repository"
16+
)
17+
18+
// ProviderProxyHandler handles provider-prefixed proxy requests like /provider/{id}/v1/messages.
19+
// Unlike the generic proxy path, provider-scoped requests are forwarded one-to-one to the
20+
// requested provider without going through the generic route selection / retry chain.
21+
type ProviderProxyHandler struct {
22+
proxyHandler *ProxyHandler
23+
modelsHandler *ModelsHandler
24+
providerRepo repository.ProviderRepository
25+
routeRepo repository.RouteRepository
26+
proxyRequestRepo repository.ProxyRequestRepository
27+
}
28+
29+
// NewProviderProxyHandler creates a new provider proxy handler.
30+
func NewProviderProxyHandler(
31+
proxyHandler *ProxyHandler,
32+
modelsHandler *ModelsHandler,
33+
providerRepo repository.ProviderRepository,
34+
routeRepo repository.RouteRepository,
35+
proxyRequestRepo repository.ProxyRequestRepository,
36+
) *ProviderProxyHandler {
37+
return &ProviderProxyHandler{
38+
proxyHandler: proxyHandler,
39+
modelsHandler: modelsHandler,
40+
providerRepo: providerRepo,
41+
routeRepo: routeRepo,
42+
proxyRequestRepo: proxyRequestRepo,
43+
}
44+
}
45+
46+
// ServeHTTP handles provider-prefixed proxy requests.
47+
func (h *ProviderProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
48+
providerID, apiPath, ok := h.parseProviderPath(r.URL.Path)
49+
if !ok {
50+
writeError(w, http.StatusNotFound, "invalid provider proxy path")
51+
return
52+
}
53+
54+
providerIDNum, err := strconv.ParseUint(providerID, 10, 64)
55+
if err != nil || providerIDNum == 0 {
56+
writeError(w, http.StatusBadRequest, "invalid provider id")
57+
return
58+
}
59+
60+
tenantID := maxxctx.GetTenantID(r.Context())
61+
provider, err := h.providerRepo.GetByID(tenantID, providerIDNum)
62+
if err != nil {
63+
log.Printf("[ProviderProxy] failed to load provider tenant=%d id=%d: %v", tenantID, providerIDNum, err)
64+
writeError(w, http.StatusInternalServerError, "internal server error")
65+
return
66+
}
67+
if provider == nil {
68+
log.Printf("[ProviderProxy] Provider not found for id: %s", providerID)
69+
writeError(w, http.StatusNotFound, "provider not found")
70+
return
71+
}
72+
73+
if apiPath == "/v1/models" {
74+
r.URL.Path = apiPath
75+
h.modelsHandler.ServeHTTP(w, r)
76+
return
77+
}
78+
79+
log.Printf("[ProviderProxy] Direct forwarding through provider: %s (ID: %d)", provider.Name, provider.ID)
80+
r.URL.Path = apiPath
81+
82+
ctx := flow.NewCtx(w, r)
83+
handlers := append([]flow.HandlerFunc{}, h.proxyHandler.extra...)
84+
handlers = append(handlers, h.directDispatch(provider))
85+
h.proxyHandler.engine.HandleWith(ctx, handlers...)
86+
}
87+
88+
func (h *ProviderProxyHandler) directDispatch(provider *domain.Provider) flow.HandlerFunc {
89+
return func(c *flow.Ctx) {
90+
tenantID := maxxctx.GetTenantID(c.Request.Context())
91+
clientType := flow.GetClientType(c)
92+
if clientType == "" {
93+
writeError(c.Writer, http.StatusBadRequest, "unable to determine client type")
94+
c.Abort()
95+
return
96+
}
97+
98+
route, err := h.routeRepo.FindByKey(tenantID, 0, provider.ID, clientType)
99+
if err != nil || route == nil {
100+
log.Printf("[ProviderProxy] route not found tenant=%d provider=%d clientType=%s: %v", tenantID, provider.ID, clientType, err)
101+
writeError(c.Writer, http.StatusNotFound, "provider route not found")
102+
c.Abort()
103+
return
104+
}
105+
106+
factory, ok := provideradapter.GetAdapterFactory(provider.Type)
107+
if !ok {
108+
writeError(c.Writer, http.StatusBadGateway, "provider adapter not found")
109+
c.Abort()
110+
return
111+
}
112+
adapter, err := factory(provider)
113+
if err != nil {
114+
log.Printf("[ProviderProxy] failed to create adapter provider=%d type=%s: %v", provider.ID, provider.Type, err)
115+
writeError(c.Writer, http.StatusBadGateway, "provider adapter init failed")
116+
c.Abort()
117+
return
118+
}
119+
if !providerSupportsClientType(adapter.SupportedClientTypes(), clientType) {
120+
writeError(c.Writer, http.StatusBadRequest, "provider does not support this client type")
121+
c.Abort()
122+
return
123+
}
124+
125+
requestModel := flow.GetRequestModel(c)
126+
mappedModel := requestModel
127+
isStream := flow.GetIsStream(c)
128+
proxyReq := h.newProxyRequest(c, route, provider, requestModel, mappedModel, isStream)
129+
if err := h.proxyRequestRepo.Create(proxyReq); err != nil {
130+
log.Printf("[ProviderProxy] failed to create proxy request: %v", err)
131+
}
132+
133+
c.Set(flow.KeyMappedModel, mappedModel)
134+
c.Set(flow.KeyOriginalClientType, clientType)
135+
c.Set(flow.KeyProxyRequest, proxyReq)
136+
137+
responseCapture := executor.NewResponseCapture(c.Writer)
138+
originalWriter := c.Writer
139+
c.Writer = responseCapture
140+
err = adapter.Execute(c, provider)
141+
c.Writer = originalWriter
142+
143+
now := time.Now()
144+
proxyReq.EndTime = now
145+
proxyReq.Duration = now.Sub(proxyReq.StartTime)
146+
proxyReq.StatusCode = responseCapture.StatusCode()
147+
proxyReq.ResponseModel = mappedModel
148+
proxyReq.ResponseInfo = &domain.ResponseInfo{
149+
Status: responseCapture.StatusCode(),
150+
Headers: responseCapture.CapturedHeaders(),
151+
Body: responseCapture.Body(),
152+
}
153+
154+
if err == nil {
155+
proxyReq.Status = "COMPLETED"
156+
_ = h.proxyRequestRepo.Update(proxyReq)
157+
return
158+
}
159+
160+
proxyReq.Status = "FAILED"
161+
proxyReq.Error = err.Error()
162+
if proxyErr, ok := err.(*domain.ProxyError); ok {
163+
if isStream {
164+
writeStreamError(responseCapture, proxyErr)
165+
} else {
166+
writeProxyError(responseCapture, proxyErr)
167+
}
168+
if proxyErr.HTTPStatusCode >= 400 && proxyErr.HTTPStatusCode < 600 {
169+
proxyReq.StatusCode = proxyErr.HTTPStatusCode
170+
}
171+
} else {
172+
writeError(responseCapture, http.StatusBadGateway, err.Error())
173+
proxyReq.StatusCode = http.StatusBadGateway
174+
}
175+
_ = h.proxyRequestRepo.Update(proxyReq)
176+
c.Abort()
177+
}
178+
}
179+
180+
func (h *ProviderProxyHandler) newProxyRequest(c *flow.Ctx, route *domain.Route, provider *domain.Provider, requestModel, mappedModel string, isStream bool) *domain.ProxyRequest {
181+
requestHeaders := flow.GetRequestHeaders(c)
182+
requestURI := flow.GetRequestURI(c)
183+
requestBody := flow.GetRequestBody(c)
184+
apiTokenID := flow.GetAPITokenID(c)
185+
projectID := flow.GetProjectID(c)
186+
tenantID := maxxctx.GetTenantID(c.Request.Context())
187+
devMode := false
188+
if v, ok := c.Get(flow.KeyAPITokenDevMode); ok {
189+
if b, ok := v.(bool); ok {
190+
devMode = b
191+
}
192+
}
193+
194+
return &domain.ProxyRequest{
195+
TenantID: tenantID,
196+
RequestID: generateProxyRequestID(),
197+
SessionID: flow.GetSessionID(c),
198+
ClientType: flow.GetClientType(c),
199+
RequestModel: requestModel,
200+
ResponseModel: mappedModel,
201+
StartTime: time.Now(),
202+
IsStream: isStream,
203+
Status: "IN_PROGRESS",
204+
StatusCode: http.StatusOK,
205+
RequestInfo: &domain.RequestInfo{
206+
Method: c.Request.Method,
207+
Headers: flattenRequestHeaders(requestHeaders),
208+
URL: requestURI,
209+
Body: string(requestBody),
210+
},
211+
RouteID: route.ID,
212+
ProviderID: provider.ID,
213+
ProjectID: projectID,
214+
APITokenID: apiTokenID,
215+
DevMode: devMode,
216+
}
217+
}
218+
219+
func generateProxyRequestID() string {
220+
return time.Now().Format("20060102150405.000000")
221+
}
222+
223+
func flattenRequestHeaders(h http.Header) map[string]string {
224+
if h == nil {
225+
return nil
226+
}
227+
result := make(map[string]string)
228+
for key, values := range h {
229+
if len(values) > 0 {
230+
result[key] = values[0]
231+
}
232+
}
233+
return result
234+
}
235+
236+
func providerSupportsClientType(supported []domain.ClientType, clientType domain.ClientType) bool {
237+
for _, ct := range supported {
238+
if ct == clientType {
239+
return true
240+
}
241+
}
242+
return false
243+
}
244+
245+
// parseProviderPath extracts the provider ID and API path from a provider-prefixed URL.
246+
func (h *ProviderProxyHandler) parseProviderPath(path string) (providerID, apiPath string, ok bool) {
247+
if !strings.HasPrefix(path, "/provider/") {
248+
return "", "", false
249+
}
250+
251+
path = strings.TrimPrefix(path, "/provider/")
252+
parts := strings.SplitN(path, "/", 2)
253+
if len(parts) < 2 {
254+
return "", "", false
255+
}
256+
257+
providerID = strings.TrimSpace(parts[0])
258+
if providerID == "" {
259+
return "", "", false
260+
}
261+
262+
apiPath = "/" + parts[1]
263+
if !isValidProviderAPIPath(apiPath) {
264+
return "", "", false
265+
}
266+
267+
return providerID, apiPath, true
268+
}
269+
270+
func isProviderProxyPath(urlPath string) bool {
271+
return strings.HasPrefix(urlPath, "/provider/")
272+
}
273+
274+
func isValidProviderAPIPath(path string) bool {
275+
if path == "/v1/messages" || strings.HasPrefix(path, "/v1/messages/") {
276+
return true
277+
}
278+
if path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/chat/completions/") {
279+
return true
280+
}
281+
if path == "/responses" || strings.HasPrefix(path, "/responses/") {
282+
return true
283+
}
284+
if path == "/v1/responses" || strings.HasPrefix(path, "/v1/responses/") {
285+
return true
286+
}
287+
if path == "/v1/models" || strings.HasPrefix(path, "/v1/models/") {
288+
return true
289+
}
290+
if path == "/v1beta/models" || strings.HasPrefix(path, "/v1beta/models/") {
291+
return true
292+
}
293+
return false
294+
}

0 commit comments

Comments
 (0)