Skip to content

Commit 3839d93

Browse files
committed
feat: add websocket routing and executor unregister API
- Introduce Server.AttachWebsocketRoute(path, handler) to mount websocket upgrade handlers on the Gin engine. - Track registered WS paths via wsRoutes with wsRouteMu to prevent duplicate registrations; initialize in NewServer and import sync. - Add Manager.UnregisterExecutor(provider) for clean executor lifecycle management. - Add github.com/gorilla/websocket v1.5.3 dependency and update go.sum. Motivation: enable services to expose WS endpoints through the core server and allow removing auth executors dynamically while avoiding duplicate route setup. No breaking changes.
1 parent a552a45 commit 3839d93

File tree

11 files changed

+1035
-3
lines changed

11 files changed

+1035
-3
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ require (
77
github.com/gin-gonic/gin v1.10.1
88
github.com/go-git/go-git/v6 v6.0.0-20251009132922-75a182125145
99
github.com/google/uuid v1.6.0
10+
github.com/gorilla/websocket v1.5.3
1011
github.com/jackc/pgx/v5 v5.7.6
1112
github.com/joho/godotenv v1.5.1
1213
github.com/klauspost/compress v1.17.4

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
6666
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
6767
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
6868
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
69+
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
70+
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
6971
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
7072
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
7173
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@@ -80,8 +82,6 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr
8082
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
8183
github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PWkxoFkQ=
8284
github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
83-
github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4slzwA=
84-
github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
8585
github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
8686
github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
8787
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=

internal/api/server.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"os"
1414
"path/filepath"
1515
"strings"
16+
"sync"
1617
"sync/atomic"
1718
"time"
1819

@@ -138,6 +139,10 @@ type Server struct {
138139
// currentPath is the absolute path to the current working directory.
139140
currentPath string
140141

142+
// wsRoutes tracks registered websocket upgrade paths.
143+
wsRouteMu sync.Mutex
144+
wsRoutes map[string]struct{}
145+
141146
// management handler
142147
mgmt *managementHandlers.Handler
143148

@@ -228,6 +233,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
228233
configFilePath: configFilePath,
229234
currentPath: wd,
230235
envManagementSecret: envManagementSecret,
236+
wsRoutes: make(map[string]struct{}),
231237
}
232238
// Save initial YAML snapshot
233239
s.oldConfigYaml, _ = yaml.Marshal(cfg)
@@ -371,6 +377,33 @@ func (s *Server) setupRoutes() {
371377
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
372378
}
373379

380+
// AttachWebsocketRoute registers a websocket upgrade handler on the primary Gin engine.
381+
// The handler is served as-is without additional middleware beyond the standard stack already configured.
382+
func (s *Server) AttachWebsocketRoute(path string, handler http.Handler) {
383+
if s == nil || s.engine == nil || handler == nil {
384+
return
385+
}
386+
trimmed := strings.TrimSpace(path)
387+
if trimmed == "" {
388+
trimmed = "/v1/ws"
389+
}
390+
if !strings.HasPrefix(trimmed, "/") {
391+
trimmed = "/" + trimmed
392+
}
393+
s.wsRouteMu.Lock()
394+
if _, exists := s.wsRoutes[trimmed]; exists {
395+
s.wsRouteMu.Unlock()
396+
return
397+
}
398+
s.wsRoutes[trimmed] = struct{}{}
399+
s.wsRouteMu.Unlock()
400+
401+
s.engine.GET(trimmed, func(c *gin.Context) {
402+
handler.ServeHTTP(c.Writer, c.Request)
403+
c.Abort()
404+
})
405+
}
406+
374407
func (s *Server) registerManagementRoutes() {
375408
if s == nil || s.engine == nil || s.mgmt == nil {
376409
return
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
package executor
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"fmt"
7+
"net/http"
8+
"net/url"
9+
"strings"
10+
11+
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
12+
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
13+
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
14+
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
15+
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
16+
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
17+
"github.com/tidwall/sjson"
18+
)
19+
20+
// AistudioExecutor routes AI Studio requests through a websocket-backed transport.
21+
type AistudioExecutor struct {
22+
provider string
23+
relay *wsrelay.Manager
24+
cfg *config.Config
25+
}
26+
27+
// NewAistudioExecutor constructs a websocket executor for the provider name.
28+
func NewAistudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Manager) *AistudioExecutor {
29+
return &AistudioExecutor{provider: strings.ToLower(provider), relay: relay, cfg: cfg}
30+
}
31+
32+
// Identifier returns the provider key served by this executor.
33+
func (e *AistudioExecutor) Identifier() string { return e.provider }
34+
35+
// PrepareRequest is a no-op because websocket transport already injects headers.
36+
func (e *AistudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
37+
return nil
38+
}
39+
40+
func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
41+
translatedReq, body, err := e.translateRequest(req, opts, false)
42+
if err != nil {
43+
return cliproxyexecutor.Response{}, err
44+
}
45+
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
46+
wsReq := &wsrelay.HTTPRequest{
47+
Method: http.MethodPost,
48+
URL: endpoint,
49+
Headers: http.Header{"Content-Type": []string{"application/json"}},
50+
Body: body.payload,
51+
}
52+
53+
var authID, authLabel, authType, authValue string
54+
if auth != nil {
55+
authID = auth.ID
56+
authLabel = auth.Label
57+
authType, authValue = auth.AccountInfo()
58+
}
59+
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
60+
URL: endpoint,
61+
Method: http.MethodPost,
62+
Headers: wsReq.Headers.Clone(),
63+
Body: bytes.Clone(body.payload),
64+
Provider: e.provider,
65+
AuthID: authID,
66+
AuthLabel: authLabel,
67+
AuthType: authType,
68+
AuthValue: authValue,
69+
})
70+
71+
resp, err := e.relay.RoundTrip(ctx, e.provider, wsReq)
72+
if err != nil {
73+
recordAPIResponseError(ctx, e.cfg, err)
74+
return cliproxyexecutor.Response{}, err
75+
}
76+
recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
77+
if len(resp.Body) > 0 {
78+
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(resp.Body))
79+
}
80+
if resp.Status < 200 || resp.Status >= 300 {
81+
return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)}
82+
}
83+
var param any
84+
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(resp.Body), &param)
85+
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
86+
}
87+
88+
func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
89+
translatedReq, body, err := e.translateRequest(req, opts, true)
90+
if err != nil {
91+
return nil, err
92+
}
93+
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
94+
wsReq := &wsrelay.HTTPRequest{
95+
Method: http.MethodPost,
96+
URL: endpoint,
97+
Headers: http.Header{"Content-Type": []string{"application/json"}},
98+
Body: body.payload,
99+
}
100+
var authID, authLabel, authType, authValue string
101+
if auth != nil {
102+
authID = auth.ID
103+
authLabel = auth.Label
104+
authType, authValue = auth.AccountInfo()
105+
}
106+
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
107+
URL: endpoint,
108+
Method: http.MethodPost,
109+
Headers: wsReq.Headers.Clone(),
110+
Body: bytes.Clone(body.payload),
111+
Provider: e.provider,
112+
AuthID: authID,
113+
AuthLabel: authLabel,
114+
AuthType: authType,
115+
AuthValue: authValue,
116+
})
117+
stream, err := e.relay.Stream(ctx, e.provider, wsReq)
118+
if err != nil {
119+
recordAPIResponseError(ctx, e.cfg, err)
120+
return nil, err
121+
}
122+
out := make(chan cliproxyexecutor.StreamChunk)
123+
go func() {
124+
defer close(out)
125+
var param any
126+
metadataLogged := false
127+
for event := range stream {
128+
if event.Err != nil {
129+
recordAPIResponseError(ctx, e.cfg, event.Err)
130+
out <- cliproxyexecutor.StreamChunk{Err: event.Err}
131+
return
132+
}
133+
switch event.Type {
134+
case wsrelay.MessageTypeStreamStart:
135+
if !metadataLogged && event.Status > 0 {
136+
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
137+
metadataLogged = true
138+
}
139+
case wsrelay.MessageTypeStreamChunk:
140+
if len(event.Payload) > 0 {
141+
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
142+
}
143+
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), &param)
144+
for i := range lines {
145+
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
146+
}
147+
case wsrelay.MessageTypeStreamEnd:
148+
return
149+
case wsrelay.MessageTypeHTTPResp:
150+
if !metadataLogged && event.Status > 0 {
151+
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
152+
metadataLogged = true
153+
}
154+
if len(event.Payload) > 0 {
155+
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
156+
}
157+
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), &param)
158+
for i := range lines {
159+
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
160+
}
161+
return
162+
case wsrelay.MessageTypeError:
163+
recordAPIResponseError(ctx, e.cfg, event.Err)
164+
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
165+
return
166+
}
167+
}
168+
}()
169+
return out, nil
170+
}
171+
172+
func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
173+
translatedReq, body, err := e.translateRequest(req, opts, false)
174+
if err != nil {
175+
return cliproxyexecutor.Response{}, err
176+
}
177+
endpoint := e.buildEndpoint(req.Model, "countTokens", "")
178+
wsReq := &wsrelay.HTTPRequest{
179+
Method: http.MethodPost,
180+
URL: endpoint,
181+
Headers: http.Header{"Content-Type": []string{"application/json"}},
182+
Body: body.payload,
183+
}
184+
var authID, authLabel, authType, authValue string
185+
if auth != nil {
186+
authID = auth.ID
187+
authLabel = auth.Label
188+
authType, authValue = auth.AccountInfo()
189+
}
190+
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
191+
URL: endpoint,
192+
Method: http.MethodPost,
193+
Headers: wsReq.Headers.Clone(),
194+
Body: bytes.Clone(body.payload),
195+
Provider: e.provider,
196+
AuthID: authID,
197+
AuthLabel: authLabel,
198+
AuthType: authType,
199+
AuthValue: authValue,
200+
})
201+
resp, err := e.relay.RoundTrip(ctx, e.provider, wsReq)
202+
if err != nil {
203+
recordAPIResponseError(ctx, e.cfg, err)
204+
return cliproxyexecutor.Response{}, err
205+
}
206+
recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
207+
if len(resp.Body) > 0 {
208+
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(resp.Body))
209+
}
210+
if resp.Status < 200 || resp.Status >= 300 {
211+
return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)}
212+
}
213+
var param any
214+
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(resp.Body), &param)
215+
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
216+
}
217+
218+
func (e *AistudioExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
219+
_ = ctx
220+
return auth, nil
221+
}
222+
223+
type translatedPayload struct {
224+
payload []byte
225+
action string
226+
toFormat sdktranslator.Format
227+
}
228+
229+
func (e *AistudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) ([]byte, translatedPayload, error) {
230+
from := opts.SourceFormat
231+
to := sdktranslator.FromString("gemini")
232+
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
233+
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok {
234+
payload = util.ApplyGeminiThinkingConfig(payload, budgetOverride, includeOverride)
235+
}
236+
payload = disableGeminiThinkingConfig(payload, req.Model)
237+
payload = fixGeminiImageAspectRatio(req.Model, payload)
238+
metadataAction := "generateContent"
239+
if req.Metadata != nil {
240+
if action, _ := req.Metadata["action"].(string); action == "countTokens" {
241+
metadataAction = action
242+
}
243+
}
244+
action := metadataAction
245+
if stream && action != "countTokens" {
246+
action = "streamGenerateContent"
247+
}
248+
payload, _ = sjson.DeleteBytes(payload, "session_id")
249+
return payload, translatedPayload{payload: payload, action: action, toFormat: to}, nil
250+
}
251+
252+
func (e *AistudioExecutor) buildEndpoint(model, action, alt string) string {
253+
base := fmt.Sprintf("%s/%s/models/%s:%s", glEndpoint, glAPIVersion, model, action)
254+
if action == "streamGenerateContent" {
255+
if alt == "" {
256+
return base + "?alt=sse"
257+
}
258+
return base + "?$alt=" + url.QueryEscape(alt)
259+
}
260+
if alt != "" && action != "countTokens" {
261+
return base + "?$alt=" + url.QueryEscape(alt)
262+
}
263+
return base
264+
}

0 commit comments

Comments
 (0)