Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions internal/adapter/provider/claude/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ func ValidateRefreshToken(ctx context.Context, refreshToken string) (*ClaudeToke
// OAuthSession represents an OAuth authorization session
type OAuthSession struct {
State string
TenantID uint64
CodeVerifier string
CreatedAt time.Time
ExpiresAt time.Time
Expand Down Expand Up @@ -113,14 +114,15 @@ func (m *OAuthManager) GenerateState() (string, error) {
}

// CreateSession creates a new OAuth session with PKCE
func (m *OAuthManager) CreateSession(state string) (*OAuthSession, *PKCEChallenge, error) {
func (m *OAuthManager) CreateSession(state string, tenantID uint64) (*OAuthSession, *PKCEChallenge, error) {
pkce, err := GeneratePKCEChallenge()
if err != nil {
return nil, nil, fmt.Errorf("failed to generate PKCE challenge: %w", err)
}

session := &OAuthSession{
State: state,
TenantID: tenantID,
CodeVerifier: pkce.CodeVerifier,
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(5 * time.Minute),
Expand Down Expand Up @@ -153,10 +155,22 @@ func (m *OAuthManager) GetSession(state string) (*OAuthSession, bool) {
// CompleteSession completes the OAuth session and broadcasts the result
func (m *OAuthManager) CompleteSession(state string, result *OAuthResult) {
result.State = state

var tenantID uint64
if val, ok := m.sessions.Load(state); ok {
if session, ok := val.(*OAuthSession); ok {
tenantID = session.TenantID
}
}

m.sessions.Delete(state)

if m.broadcaster != nil {
m.broadcaster.BroadcastMessage("claude_oauth_result", result)
if tenantID != 0 {
m.broadcaster.BroadcastMessageToTenant(tenantID, "claude_oauth_result", result)
} else {
m.broadcaster.BroadcastMessage("claude_oauth_result", result)
}
}
}

Expand Down
17 changes: 15 additions & 2 deletions internal/adapter/provider/codex/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ func ValidateRefreshToken(ctx context.Context, refreshToken string) (*CodexToken
// OAuthSession represents an OAuth authorization session
type OAuthSession struct {
State string
TenantID uint64
CodeVerifier string
CreatedAt time.Time
ExpiresAt time.Time
Expand Down Expand Up @@ -137,7 +138,7 @@ func (m *OAuthManager) GenerateState() (string, error) {
}

// CreateSession creates a new OAuth session with PKCE
func (m *OAuthManager) CreateSession(state string) (*OAuthSession, *PKCEChallenge, error) {
func (m *OAuthManager) CreateSession(state string, tenantID uint64) (*OAuthSession, *PKCEChallenge, error) {
// Generate PKCE challenge
pkce, err := GeneratePKCEChallenge()
if err != nil {
Expand All @@ -146,6 +147,7 @@ func (m *OAuthManager) CreateSession(state string) (*OAuthSession, *PKCEChalleng

session := &OAuthSession{
State: state,
TenantID: tenantID,
CodeVerifier: pkce.CodeVerifier,
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(5 * time.Minute), // 5 minute timeout
Expand Down Expand Up @@ -181,12 +183,23 @@ func (m *OAuthManager) CompleteSession(state string, result *OAuthResult) {
// Ensure state matches
result.State = state

var tenantID uint64
if val, ok := m.sessions.Load(state); ok {
if session, ok := val.(*OAuthSession); ok {
tenantID = session.TenantID
}
}

// Delete session
m.sessions.Delete(state)

// Broadcast result via WebSocket
if m.broadcaster != nil {
m.broadcaster.BroadcastMessage("codex_oauth_result", result)
if tenantID != 0 {
m.broadcaster.BroadcastMessageToTenant(tenantID, "codex_oauth_result", result)
} else {
m.broadcaster.BroadcastMessage("codex_oauth_result", result)
}
}
}

Expand Down
2 changes: 2 additions & 0 deletions internal/event/broadcaster.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ type Broadcaster interface {
BroadcastProxyUpstreamAttempt(attempt *domain.ProxyUpstreamAttempt)
BroadcastLog(message string)
BroadcastMessage(messageType string, data interface{})
BroadcastMessageToTenant(tenantID uint64, messageType string, data interface{})
}

// NopBroadcaster 空实现,用于测试或不需要广播的场景
Expand All @@ -18,6 +19,7 @@ func (n *NopBroadcaster) BroadcastProxyRequest(req *domain.ProxyRequest)
func (n *NopBroadcaster) BroadcastProxyUpstreamAttempt(attempt *domain.ProxyUpstreamAttempt) {}
func (n *NopBroadcaster) BroadcastLog(message string) {}
func (n *NopBroadcaster) BroadcastMessage(messageType string, data interface{}) {}
func (n *NopBroadcaster) BroadcastMessageToTenant(tenantID uint64, messageType string, data interface{}) {}

// SanitizeProxyRequestForBroadcast 用于“实时广播”场景瘦身 payload:
// 去掉 request/response 大字段,避免 WebSocket 消息动辄几十/几百 KB,导致前端 JSON.parse / GC 卡死。
Expand Down
8 changes: 8 additions & 0 deletions internal/event/wails_broadcaster_desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ func (w *WailsBroadcaster) BroadcastLog(message string) {
w.emitWailsEvent("log_message", message)
}

// BroadcastMessageToTenant broadcasts a tenant-scoped custom message
func (w *WailsBroadcaster) BroadcastMessageToTenant(tenantID uint64, messageType string, data interface{}) {
if w.inner != nil {
w.inner.BroadcastMessageToTenant(tenantID, messageType, data)
}
w.emitWailsEvent(messageType, data)
}

// BroadcastMessage broadcasts a custom message
func (w *WailsBroadcaster) BroadcastMessage(messageType string, data interface{}) {
if w.inner != nil {
Expand Down
7 changes: 7 additions & 0 deletions internal/event/wails_broadcaster_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ func (w *WailsBroadcaster) BroadcastLog(message string) {
}
}

// BroadcastMessageToTenant broadcasts a tenant-scoped custom message
func (w *WailsBroadcaster) BroadcastMessageToTenant(tenantID uint64, messageType string, data interface{}) {
if w.inner != nil {
w.inner.BroadcastMessageToTenant(tenantID, messageType, data)
}
}

// BroadcastMessage broadcasts a custom message
func (w *WailsBroadcaster) BroadcastMessage(messageType string, data interface{}) {
if w.inner != nil {
Expand Down
7 changes: 4 additions & 3 deletions internal/handler/claude.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,13 @@ type ClaudeOAuthStartResult struct {
}

// StartOAuth starts the OAuth authorization flow
func (h *ClaudeHandler) StartOAuth() (*ClaudeOAuthStartResult, error) {
func (h *ClaudeHandler) StartOAuth(tenantID uint64) (*ClaudeOAuthStartResult, error) {
state, err := h.oauthManager.GenerateState()
if err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err)
}

_, pkce, err := h.oauthManager.CreateSession(state)
_, pkce, err := h.oauthManager.CreateSession(state, tenantID)
if err != nil {
return nil, fmt.Errorf("failed to create session: %w", err)
}
Expand Down Expand Up @@ -158,7 +158,8 @@ func (h *ClaudeHandler) handleOAuthStart(w http.ResponseWriter, r *http.Request)
cancel()
}

result, err := h.StartOAuth()
tenantID := maxxctx.GetTenantID(r.Context())
result, err := h.StartOAuth(tenantID)
if err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
return
Expand Down
7 changes: 4 additions & 3 deletions internal/handler/codex.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,15 +146,15 @@ type CodexOAuthStartResult struct {
}

// StartOAuth starts the OAuth authorization flow
func (h *CodexHandler) StartOAuth() (*CodexOAuthStartResult, error) {
func (h *CodexHandler) StartOAuth(tenantID uint64) (*CodexOAuthStartResult, error) {
// Generate random state token
state, err := h.oauthManager.GenerateState()
if err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err)
}

// Create OAuth session with PKCE
_, pkce, err := h.oauthManager.CreateSession(state)
_, pkce, err := h.oauthManager.CreateSession(state, tenantID)
if err != nil {
return nil, fmt.Errorf("failed to create session: %w", err)
}
Expand Down Expand Up @@ -207,7 +207,8 @@ func (h *CodexHandler) handleOAuthStart(w http.ResponseWriter, r *http.Request)
cancel()
}

result, err := h.StartOAuth()
tenantID := maxxctx.GetTenantID(r.Context())
result, err := h.StartOAuth(tenantID)
if err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
return
Expand Down
Loading
Loading