Skip to content

Commit 359b8de

Browse files
committed
feat(ws): add WebSocket auth
1 parent ea6065f commit 359b8de

File tree

9 files changed

+119
-14
lines changed

9 files changed

+119
-14
lines changed

config.example.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ quota-exceeded:
4343
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
4444
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
4545

46+
# When true, enable authentication for the WebSocket API (/v1/ws).
47+
ws-auth: false
48+
4649
# API keys for official Generative Language API
4750
#generative-language-api-key:
4851
# - "AIzaSy...01"

internal/access/config_access/provider.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,12 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
5757
authHeaderGoogle := r.Header.Get("X-Goog-Api-Key")
5858
authHeaderAnthropic := r.Header.Get("X-Api-Key")
5959
queryKey := ""
60+
queryAuthToken := ""
6061
if r.URL != nil {
6162
queryKey = r.URL.Query().Get("key")
63+
queryAuthToken = r.URL.Query().Get("auth_token")
6264
}
63-
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" {
65+
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" && queryAuthToken == "" {
6466
return nil, sdkaccess.ErrNoCredentials
6567
}
6668

@@ -74,6 +76,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
7476
{authHeaderGoogle, "x-goog-api-key"},
7577
{authHeaderAnthropic, "x-api-key"},
7678
{queryKey, "query-key"},
79+
{queryAuthToken, "query-auth-token"},
7780
}
7881

7982
for _, candidate := range candidates {

internal/api/middleware/request_logging.go

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"github.com/gin-gonic/gin"
1212
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
13+
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
1314
)
1415

1516
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
@@ -63,13 +64,11 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
6364
// It captures the URL, method, headers, and body. The request body is read and then
6465
// restored so that it can be processed by subsequent handlers.
6566
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
66-
// Capture URL
67-
url := c.Request.URL.String()
68-
if c.Request.URL.Path != "" {
69-
url = c.Request.URL.Path
70-
if c.Request.URL.RawQuery != "" {
71-
url += "?" + c.Request.URL.RawQuery
72-
}
67+
// Capture URL with sensitive query parameters masked
68+
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
69+
url := c.Request.URL.Path
70+
if maskedQuery != "" {
71+
url += "?" + maskedQuery
7372
}
7473

7574
// Capture method

internal/api/server.go

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,10 @@ type Server struct {
140140
currentPath string
141141

142142
// wsRoutes tracks registered websocket upgrade paths.
143-
wsRouteMu sync.Mutex
144-
wsRoutes map[string]struct{}
143+
wsRouteMu sync.Mutex
144+
wsRoutes map[string]struct{}
145+
wsAuthChanged func(bool, bool)
146+
wsAuthEnabled atomic.Bool
145147

146148
// management handler
147149
mgmt *managementHandlers.Handler
@@ -235,6 +237,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
235237
envManagementSecret: envManagementSecret,
236238
wsRoutes: make(map[string]struct{}),
237239
}
240+
s.wsAuthEnabled.Store(cfg.WebsocketAuth)
238241
// Save initial YAML snapshot
239242
s.oldConfigYaml, _ = yaml.Marshal(cfg)
240243
s.applyAccessConfig(nil, cfg)
@@ -398,10 +401,20 @@ func (s *Server) AttachWebsocketRoute(path string, handler http.Handler) {
398401
s.wsRoutes[trimmed] = struct{}{}
399402
s.wsRouteMu.Unlock()
400403

401-
s.engine.GET(trimmed, func(c *gin.Context) {
404+
authMiddleware := AuthMiddleware(s.accessManager)
405+
conditionalAuth := func(c *gin.Context) {
406+
if !s.wsAuthEnabled.Load() {
407+
c.Next()
408+
return
409+
}
410+
authMiddleware(c)
411+
}
412+
finalHandler := func(c *gin.Context) {
402413
handler.ServeHTTP(c.Writer, c.Request)
403414
c.Abort()
404-
})
415+
}
416+
417+
s.engine.GET(trimmed, conditionalAuth, finalHandler)
405418
}
406419

407420
func (s *Server) registerManagementRoutes() {
@@ -803,6 +816,10 @@ func (s *Server) UpdateClients(cfg *config.Config) {
803816

804817
s.applyAccessConfig(oldCfg, cfg)
805818
s.cfg = cfg
819+
s.wsAuthEnabled.Store(cfg.WebsocketAuth)
820+
if oldCfg != nil && s.wsAuthChanged != nil && oldCfg.WebsocketAuth != cfg.WebsocketAuth {
821+
s.wsAuthChanged(oldCfg.WebsocketAuth, cfg.WebsocketAuth)
822+
}
806823
managementasset.SetCurrentConfig(cfg)
807824
// Save YAML snapshot for next comparison
808825
s.oldConfigYaml, _ = yaml.Marshal(cfg)
@@ -843,6 +860,13 @@ func (s *Server) UpdateClients(cfg *config.Config) {
843860
)
844861
}
845862

863+
func (s *Server) SetWebsocketAuthChangeHandler(fn func(bool, bool)) {
864+
if s == nil {
865+
return
866+
}
867+
s.wsAuthChanged = fn
868+
}
869+
846870
// (management handlers moved to internal/api/handlers/management)
847871

848872
// AuthMiddleware returns a Gin middleware handler that authenticates requests

internal/config/config.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ type Config struct {
4040
// QuotaExceeded defines the behavior when a quota is exceeded.
4141
QuotaExceeded QuotaExceeded `yaml:"quota-exceeded" json:"quota-exceeded"`
4242

43+
// WebsocketAuth enables or disables authentication for the WebSocket API.
44+
WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"`
45+
4346
// GlAPIKey is the API key for the generative language API.
4447
GlAPIKey []string `yaml:"generative-language-api-key" json:"generative-language-api-key"`
4548

internal/logging/gin_logger.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"time"
1111

1212
"github.com/gin-gonic/gin"
13+
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
1314
log "github.com/sirupsen/logrus"
1415
)
1516

@@ -23,7 +24,7 @@ func GinLogrusLogger() gin.HandlerFunc {
2324
return func(c *gin.Context) {
2425
start := time.Now()
2526
path := c.Request.URL.Path
26-
raw := c.Request.URL.RawQuery
27+
raw := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
2728

2829
c.Next()
2930

internal/util/provider.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
package util
55

66
import (
7+
"net/url"
78
"strings"
89

910
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
@@ -188,3 +189,56 @@ func MaskSensitiveHeaderValue(key, value string) string {
188189
return value
189190
}
190191
}
192+
193+
// MaskSensitiveQuery masks sensitive query parameters, e.g. auth_token, within the raw query string.
194+
func MaskSensitiveQuery(raw string) string {
195+
if raw == "" {
196+
return ""
197+
}
198+
parts := strings.Split(raw, "&")
199+
changed := false
200+
for i, part := range parts {
201+
if part == "" {
202+
continue
203+
}
204+
keyPart := part
205+
valuePart := ""
206+
if idx := strings.Index(part, "="); idx >= 0 {
207+
keyPart = part[:idx]
208+
valuePart = part[idx+1:]
209+
}
210+
decodedKey, err := url.QueryUnescape(keyPart)
211+
if err != nil {
212+
decodedKey = keyPart
213+
}
214+
if !shouldMaskQueryParam(decodedKey) {
215+
continue
216+
}
217+
decodedValue, err := url.QueryUnescape(valuePart)
218+
if err != nil {
219+
decodedValue = valuePart
220+
}
221+
masked := HideAPIKey(strings.TrimSpace(decodedValue))
222+
parts[i] = keyPart + "=" + url.QueryEscape(masked)
223+
changed = true
224+
}
225+
if !changed {
226+
return raw
227+
}
228+
return strings.Join(parts, "&")
229+
}
230+
231+
func shouldMaskQueryParam(key string) bool {
232+
key = strings.ToLower(strings.TrimSpace(key))
233+
if key == "" {
234+
return false
235+
}
236+
key = strings.TrimSuffix(key, "[]")
237+
if key == "key" || strings.Contains(key, "api-key") || strings.Contains(key, "apikey") || strings.Contains(key, "api_key") {
238+
return true
239+
}
240+
if strings.Contains(key, "token") || strings.Contains(key, "secret") {
241+
return true
242+
}
243+
return false
244+
}

internal/watcher/watcher.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,6 +1204,9 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
12041204
if oldCfg.ProxyURL != newCfg.ProxyURL {
12051205
changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", oldCfg.ProxyURL, newCfg.ProxyURL))
12061206
}
1207+
if oldCfg.WebsocketAuth != newCfg.WebsocketAuth {
1208+
changes = append(changes, fmt.Sprintf("ws-auth: %t -> %t", oldCfg.WebsocketAuth, newCfg.WebsocketAuth))
1209+
}
12071210

12081211
// Quota-exceeded behavior
12091212
if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject {

sdk/cliproxy/service.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,22 @@ func (s *Service) Run(ctx context.Context) error {
421421
s.ensureWebsocketGateway()
422422
if s.server != nil && s.wsGateway != nil {
423423
s.server.AttachWebsocketRoute(s.wsGateway.Path(), s.wsGateway.Handler())
424+
s.server.SetWebsocketAuthChangeHandler(func(oldEnabled, newEnabled bool) {
425+
if oldEnabled == newEnabled {
426+
return
427+
}
428+
if !oldEnabled && newEnabled {
429+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
430+
defer cancel()
431+
if errStop := s.wsGateway.Stop(ctx); errStop != nil {
432+
log.Warnf("failed to reset websocket connections after ws-auth change %t -> %t: %v", oldEnabled, newEnabled, errStop)
433+
return
434+
}
435+
log.Debugf("ws-auth enabled; existing websocket sessions terminated to enforce authentication")
436+
return
437+
}
438+
log.Debugf("ws-auth disabled; existing websocket sessions remain connected")
439+
})
424440
}
425441

426442
if s.hooks.OnBeforeStart != nil {
@@ -460,7 +476,6 @@ func (s *Service) Run(ctx context.Context) error {
460476
s.cfg = newCfg
461477
s.cfgMu.Unlock()
462478
s.rebindExecutors()
463-
464479
}
465480

466481
watcherWrapper, err = s.watcherFactory(s.configPath, s.cfg.AuthDir, reloadCallback)

0 commit comments

Comments
 (0)