Skip to content

Commit 673631e

Browse files
committed
fix: cancel old endpoint requests on manual switch
1 parent 5edd578 commit 673631e

File tree

2 files changed

+52
-5
lines changed

2 files changed

+52
-5
lines changed

internal/proxy/proxy.go

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package proxy
22

33
import (
4+
"context"
45
"encoding/json"
56
"fmt"
67
"io"
@@ -37,8 +38,11 @@ type Proxy struct {
3738
currentIndex int
3839
mu sync.RWMutex
3940
server *http.Server
40-
activeRequests map[string]bool // tracks active requests by endpoint name
41-
activeRequestsMu sync.RWMutex // protects activeRequests map
41+
activeRequests map[string]bool // tracks active requests by endpoint name
42+
activeRequestsMu sync.RWMutex // protects activeRequests map
43+
endpointCtx map[string]context.Context // context per endpoint for cancellation
44+
endpointCancel map[string]context.CancelFunc // cancel functions per endpoint
45+
ctxMu sync.RWMutex // protects context maps
4246
}
4347

4448
// New creates a new Proxy instance
@@ -50,6 +54,8 @@ func New(cfg *config.Config, statsStorage StatsStorage, deviceID string) *Proxy
5054
stats: stats,
5155
currentIndex: 0,
5256
activeRequests: make(map[string]bool),
57+
endpointCtx: make(map[string]context.Context),
58+
endpointCancel: make(map[string]context.CancelFunc),
5359
}
5460
}
5561

@@ -136,6 +142,33 @@ func (p *Proxy) isCurrentEndpoint(endpointName string) bool {
136142
return current.Name == endpointName
137143
}
138144

145+
// getEndpointContext returns a context for the given endpoint, creating one if needed
146+
func (p *Proxy) getEndpointContext(endpointName string) context.Context {
147+
p.ctxMu.Lock()
148+
defer p.ctxMu.Unlock()
149+
150+
if ctx, ok := p.endpointCtx[endpointName]; ok {
151+
return ctx
152+
}
153+
154+
ctx, cancel := context.WithCancel(context.Background())
155+
p.endpointCtx[endpointName] = ctx
156+
p.endpointCancel[endpointName] = cancel
157+
return ctx
158+
}
159+
160+
// cancelEndpointRequests cancels all requests for the given endpoint
161+
func (p *Proxy) cancelEndpointRequests(endpointName string) {
162+
p.ctxMu.Lock()
163+
defer p.ctxMu.Unlock()
164+
165+
if cancel, ok := p.endpointCancel[endpointName]; ok {
166+
cancel()
167+
delete(p.endpointCtx, endpointName)
168+
delete(p.endpointCancel, endpointName)
169+
}
170+
}
171+
139172
// rotateEndpoint switches to the next endpoint (thread-safe)
140173
// waitForActive: if true, waits briefly for active requests to complete before switching
141174
func (p *Proxy) rotateEndpoint() config.Endpoint {
@@ -189,7 +222,7 @@ func (p *Proxy) GetCurrentEndpointName() string {
189222

190223
// SetCurrentEndpoint manually switches to a specific endpoint by name
191224
// Returns error if endpoint not found or not enabled
192-
// Thread-safe and won't affect ongoing requests
225+
// Thread-safe and cancels ongoing requests on the old endpoint
193226
func (p *Proxy) SetCurrentEndpoint(targetName string) error {
194227
p.mu.Lock()
195228
defer p.mu.Unlock()
@@ -203,6 +236,10 @@ func (p *Proxy) SetCurrentEndpoint(targetName string) error {
203236
for i, ep := range endpoints {
204237
if ep.Name == targetName {
205238
oldEndpoint := endpoints[p.currentIndex%len(endpoints)]
239+
if oldEndpoint.Name != targetName {
240+
// Cancel all requests on the old endpoint
241+
p.cancelEndpointRequests(oldEndpoint.Name)
242+
}
206243
p.currentIndex = i
207244
logger.Info("[MANUAL SWITCH] %s → %s", oldEndpoint.Name, ep.Name)
208245
return nil
@@ -266,6 +303,7 @@ func (p *Proxy) handleProxy(w http.ResponseWriter, r *http.Request) {
266303

267304
maxRetries := len(endpoints) * 2
268305
endpointAttempts := 0
306+
lastEndpointName := ""
269307

270308
for retry := 0; retry < maxRetries; retry++ {
271309
endpoint := p.getCurrentEndpoint()
@@ -274,6 +312,12 @@ func (p *Proxy) handleProxy(w http.ResponseWriter, r *http.Request) {
274312
return
275313
}
276314

315+
// Reset attempts counter if endpoint changed (e.g., manual switch)
316+
if lastEndpointName != "" && lastEndpointName != endpoint.Name {
317+
endpointAttempts = 0
318+
}
319+
lastEndpointName = endpoint.Name
320+
277321
endpointAttempts++
278322
p.markRequestActive(endpoint.Name)
279323
p.stats.RecordRequest(endpoint.Name)
@@ -336,7 +380,8 @@ func (p *Proxy) handleProxy(w http.ResponseWriter, r *http.Request) {
336380
continue
337381
}
338382

339-
resp, err := sendRequest(proxyReq)
383+
ctx := p.getEndpointContext(endpoint.Name)
384+
resp, err := sendRequest(ctx, proxyReq)
340385
if err != nil {
341386
logger.Error("[%s] Request failed: %v", endpoint.Name, err)
342387
p.stats.RecordError(endpoint.Name)

internal/proxy/request.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package proxy
22

33
import (
44
"bytes"
5+
"context"
56
"encoding/json"
67
"fmt"
78
"net/http"
@@ -198,7 +199,8 @@ func buildProxyRequest(r *http.Request, endpoint config.Endpoint, transformedBod
198199
}
199200

200201
// sendRequest sends the HTTP request and returns the response
201-
func sendRequest(proxyReq *http.Request) (*http.Response, error) {
202+
func sendRequest(ctx context.Context, proxyReq *http.Request) (*http.Response, error) {
203+
proxyReq = proxyReq.WithContext(ctx)
202204
client := &http.Client{
203205
Timeout: 300 * time.Second,
204206
}

0 commit comments

Comments
 (0)