11package proxy
22
33import (
4+ "context"
45 "encoding/json"
56 "fmt"
67 "io"
@@ -37,8 +38,11 @@ type Proxy struct {
3738 currentIndex int
3839 mu sync.RWMutex
3940 server * http.Server
40- activeRequests map [string ]bool // tracks active requests by endpoint name
41- activeRequestsMu sync.RWMutex // protects activeRequests map
41+ activeRequests map [string ]bool // tracks active requests by endpoint name
42+ activeRequestsMu sync.RWMutex // protects activeRequests map
43+ endpointCtx map [string ]context.Context // context per endpoint for cancellation
44+ endpointCancel map [string ]context.CancelFunc // cancel functions per endpoint
45+ ctxMu sync.RWMutex // protects context maps
4246}
4347
4448// New creates a new Proxy instance
@@ -50,6 +54,8 @@ func New(cfg *config.Config, statsStorage StatsStorage, deviceID string) *Proxy
5054 stats : stats ,
5155 currentIndex : 0 ,
5256 activeRequests : make (map [string ]bool ),
57+ endpointCtx : make (map [string ]context.Context ),
58+ endpointCancel : make (map [string ]context.CancelFunc ),
5359 }
5460}
5561
@@ -136,6 +142,33 @@ func (p *Proxy) isCurrentEndpoint(endpointName string) bool {
136142 return current .Name == endpointName
137143}
138144
145+ // getEndpointContext returns a context for the given endpoint, creating one if needed
146+ func (p * Proxy ) getEndpointContext (endpointName string ) context.Context {
147+ p .ctxMu .Lock ()
148+ defer p .ctxMu .Unlock ()
149+
150+ if ctx , ok := p .endpointCtx [endpointName ]; ok {
151+ return ctx
152+ }
153+
154+ ctx , cancel := context .WithCancel (context .Background ())
155+ p .endpointCtx [endpointName ] = ctx
156+ p .endpointCancel [endpointName ] = cancel
157+ return ctx
158+ }
159+
160+ // cancelEndpointRequests cancels all requests for the given endpoint
161+ func (p * Proxy ) cancelEndpointRequests (endpointName string ) {
162+ p .ctxMu .Lock ()
163+ defer p .ctxMu .Unlock ()
164+
165+ if cancel , ok := p .endpointCancel [endpointName ]; ok {
166+ cancel ()
167+ delete (p .endpointCtx , endpointName )
168+ delete (p .endpointCancel , endpointName )
169+ }
170+ }
171+
139172// rotateEndpoint switches to the next endpoint (thread-safe)
140173// waitForActive: if true, waits briefly for active requests to complete before switching
141174func (p * Proxy ) rotateEndpoint () config.Endpoint {
@@ -189,7 +222,7 @@ func (p *Proxy) GetCurrentEndpointName() string {
189222
190223// SetCurrentEndpoint manually switches to a specific endpoint by name
191224// Returns error if endpoint not found or not enabled
192- // Thread-safe and won't affect ongoing requests
225+ // Thread-safe and cancels ongoing requests on the old endpoint
193226func (p * Proxy ) SetCurrentEndpoint (targetName string ) error {
194227 p .mu .Lock ()
195228 defer p .mu .Unlock ()
@@ -203,6 +236,10 @@ func (p *Proxy) SetCurrentEndpoint(targetName string) error {
203236 for i , ep := range endpoints {
204237 if ep .Name == targetName {
205238 oldEndpoint := endpoints [p .currentIndex % len (endpoints )]
239+ if oldEndpoint .Name != targetName {
240+ // Cancel all requests on the old endpoint
241+ p .cancelEndpointRequests (oldEndpoint .Name )
242+ }
206243 p .currentIndex = i
207244 logger .Info ("[MANUAL SWITCH] %s → %s" , oldEndpoint .Name , ep .Name )
208245 return nil
@@ -266,6 +303,7 @@ func (p *Proxy) handleProxy(w http.ResponseWriter, r *http.Request) {
266303
267304 maxRetries := len (endpoints ) * 2
268305 endpointAttempts := 0
306+ lastEndpointName := ""
269307
270308 for retry := 0 ; retry < maxRetries ; retry ++ {
271309 endpoint := p .getCurrentEndpoint ()
@@ -274,6 +312,12 @@ func (p *Proxy) handleProxy(w http.ResponseWriter, r *http.Request) {
274312 return
275313 }
276314
315+ // Reset attempts counter if endpoint changed (e.g., manual switch)
316+ if lastEndpointName != "" && lastEndpointName != endpoint .Name {
317+ endpointAttempts = 0
318+ }
319+ lastEndpointName = endpoint .Name
320+
277321 endpointAttempts ++
278322 p .markRequestActive (endpoint .Name )
279323 p .stats .RecordRequest (endpoint .Name )
@@ -336,7 +380,8 @@ func (p *Proxy) handleProxy(w http.ResponseWriter, r *http.Request) {
336380 continue
337381 }
338382
339- resp , err := sendRequest (proxyReq )
383+ ctx := p .getEndpointContext (endpoint .Name )
384+ resp , err := sendRequest (ctx , proxyReq )
340385 if err != nil {
341386 logger .Error ("[%s] Request failed: %v" , endpoint .Name , err )
342387 p .stats .RecordError (endpoint .Name )
0 commit comments