Skip to content

Commit 967c801

Browse files
committed
initial commit
1 parent e587a1c commit 967c801

File tree

2 files changed

+205
-16
lines changed

2 files changed

+205
-16
lines changed

pkg/gofr/service/circuit_breaker.go

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -60,33 +60,35 @@ func NewCircuitBreaker(config CircuitBreakerConfig, h HTTP) *circuitBreaker {
6060
// executeWithCircuitBreaker executes the given function with circuit breaker protection.
6161
func (cb *circuitBreaker) executeWithCircuitBreaker(ctx context.Context, f func(ctx context.Context) (*http.Response,
6262
error)) (*http.Response, error) {
63-
cb.mu.Lock()
64-
defer cb.mu.Unlock()
63+
cb.mu.RLock()
64+
isOpen := cb.state == OpenState
65+
cb.mu.RUnlock()
6566

66-
if cb.state == OpenState {
67-
if time.Since(cb.lastChecked) > cb.interval {
68-
// Check health before potentially closing the circuit
69-
if cb.healthCheck(ctx) {
70-
cb.resetCircuit()
71-
return nil, nil
72-
}
67+
if isOpen {
68+
// Circuit is open - try recovery without holding lock
69+
if !cb.tryCircuitRecovery() {
70+
return nil, ErrCircuitOpen
7371
}
74-
75-
return nil, ErrCircuitOpen
72+
// Circuit recovered, proceed with request
7673
}
7774

7875
result, err := f(ctx)
76+
77+
cb.mu.Lock()
78+
defer cb.mu.Unlock()
79+
7980
if err != nil || (result != nil && result.StatusCode > 500) {
8081
cb.handleFailure()
82+
83+
if cb.failureCount > cb.threshold {
84+
cb.openCircuit()
85+
86+
return nil, ErrCircuitOpen
87+
}
8188
} else {
8289
cb.resetFailureCount()
8390
}
8491

85-
if cb.failureCount > cb.threshold {
86-
cb.openCircuit()
87-
return nil, ErrCircuitOpen
88-
}
89-
9092
return result, err
9193
}
9294

pkg/gofr/service/circuit_breaker_test.go

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"io"
66
"net/http"
77
"net/http/httptest"
8+
"sync"
89
"testing"
910
"time"
1011

@@ -927,3 +928,189 @@ func TestCircuitBreaker_HealthEndpointWithTimeout(t *testing.T) {
927928
assert.Equal(t, http.StatusOK, resp.StatusCode)
928929
resp.Body.Close()
929930
}
931+
932+
// TestCircuitBreaker_ParallelExecution tests that requests execute in parallel.
933+
func TestCircuitBreaker_ParallelExecution(t *testing.T) {
934+
requestCount := 0
935+
mu := sync.Mutex{}
936+
937+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
938+
mu.Lock()
939+
940+
requestCount++
941+
942+
mu.Unlock()
943+
944+
time.Sleep(1 * time.Second) // Simulate slow endpoint
945+
w.WriteHeader(http.StatusOK)
946+
947+
_, _ = w.Write([]byte(`{"status": "ok"}`))
948+
}))
949+
defer server.Close()
950+
951+
ctrl := gomock.NewController(t)
952+
mockMetric := NewMockMetrics(ctrl)
953+
954+
mockMetric.EXPECT().RecordHistogram(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
955+
mockMetric.EXPECT().NewCounter(gomock.Any(), gomock.Any()).AnyTimes()
956+
mockMetric.EXPECT().NewGauge(gomock.Any(), gomock.Any()).AnyTimes()
957+
mockMetric.EXPECT().SetGauge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
958+
959+
httpSvc := NewHTTPService(server.URL, logging.NewMockLogger(logging.DEBUG), mockMetric,
960+
&CircuitBreakerConfig{
961+
Threshold: 10,
962+
Interval: 5 * time.Second,
963+
})
964+
965+
startTime := time.Now()
966+
967+
var wg sync.WaitGroup
968+
969+
numRequests := 5
970+
971+
errors := make([]error, numRequests)
972+
973+
// Launch 5 concurrent requests
974+
for i := 0; i < numRequests; i++ {
975+
wg.Add(1)
976+
977+
go func(index int) {
978+
defer wg.Done()
979+
980+
resp, err := httpSvc.Get(t.Context(), "test", nil)
981+
errors[index] = err
982+
983+
if err == nil && resp != nil {
984+
_, _ = io.ReadAll(resp.Body)
985+
986+
_ = resp.Body.Close()
987+
}
988+
}(i)
989+
}
990+
991+
wg.Wait()
992+
993+
totalTime := time.Since(startTime)
994+
995+
// Verify all requests completed successfully
996+
for i := 0; i < numRequests; i++ {
997+
require.NoError(t, errors[i], "Request %d should not error", i)
998+
}
999+
1000+
// All 5 requests should complete in ~2s (parallel)
1001+
assert.Less(t, totalTime, 4*time.Second, "Requests should execute in parallel")
1002+
assert.Equal(t, numRequests, requestCount, "All requests should have been processed")
1003+
}
1004+
1005+
// TestCircuitBreaker_ConcurrentFailures tests thread safety during concurrent failures.
1006+
func TestCircuitBreaker_ConcurrentFailures(t *testing.T) {
1007+
failCount := 0
1008+
mu := sync.Mutex{}
1009+
1010+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
1011+
mu.Lock()
1012+
1013+
failCount++
1014+
1015+
current := failCount
1016+
1017+
mu.Unlock()
1018+
1019+
// First 3 requests fail, rest succeed
1020+
if current <= 3 {
1021+
w.WriteHeader(http.StatusServiceUnavailable)
1022+
} else {
1023+
w.WriteHeader(http.StatusOK)
1024+
}
1025+
}))
1026+
defer server.Close()
1027+
1028+
ctrl := gomock.NewController(t)
1029+
mockMetric := NewMockMetrics(ctrl)
1030+
1031+
mockMetric.EXPECT().RecordHistogram(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
1032+
mockMetric.EXPECT().NewCounter(gomock.Any(), gomock.Any()).AnyTimes()
1033+
mockMetric.EXPECT().NewGauge(gomock.Any(), gomock.Any()).AnyTimes()
1034+
mockMetric.EXPECT().SetGauge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
1035+
1036+
httpSvc := NewHTTPService(server.URL, logging.NewMockLogger(logging.DEBUG), mockMetric,
1037+
&CircuitBreakerConfig{
1038+
Threshold: 2,
1039+
Interval: 1 * time.Second,
1040+
})
1041+
1042+
var wg sync.WaitGroup
1043+
1044+
numRequests := 10
1045+
1046+
for i := 0; i < numRequests; i++ {
1047+
wg.Add(1)
1048+
1049+
go func() {
1050+
defer wg.Done()
1051+
1052+
resp, _ := httpSvc.Get(t.Context(), "test", nil)
1053+
if resp != nil {
1054+
_ = resp.Body.Close()
1055+
}
1056+
}()
1057+
}
1058+
1059+
wg.Wait()
1060+
}
1061+
1062+
// TestCircuitBreaker_MixedHTTPMethods tests parallel requests with different HTTP methods.
1063+
func TestCircuitBreaker_MixedHTTPMethods(t *testing.T) {
1064+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
1065+
time.Sleep(1 * time.Second)
1066+
w.WriteHeader(http.StatusOK)
1067+
}))
1068+
defer server.Close()
1069+
1070+
ctrl := gomock.NewController(t)
1071+
mockMetric := NewMockMetrics(ctrl)
1072+
1073+
mockMetric.EXPECT().RecordHistogram(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
1074+
mockMetric.EXPECT().NewCounter(gomock.Any(), gomock.Any()).AnyTimes()
1075+
mockMetric.EXPECT().NewGauge(gomock.Any(), gomock.Any()).AnyTimes()
1076+
mockMetric.EXPECT().SetGauge(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
1077+
1078+
httpSvc := NewHTTPService(server.URL, logging.NewMockLogger(logging.DEBUG), mockMetric,
1079+
&CircuitBreakerConfig{
1080+
Threshold: 5,
1081+
Interval: 2 * time.Second,
1082+
})
1083+
1084+
startTime := time.Now()
1085+
1086+
var wg sync.WaitGroup
1087+
1088+
// Test all HTTP methods in parallel
1089+
methods := []func() (*http.Response, error){
1090+
func() (*http.Response, error) { return httpSvc.Get(t.Context(), "test", nil) },
1091+
func() (*http.Response, error) { return httpSvc.Post(t.Context(), "test", nil, []byte(`{}`)) },
1092+
func() (*http.Response, error) { return httpSvc.Put(t.Context(), "test", nil, []byte(`{}`)) },
1093+
func() (*http.Response, error) { return httpSvc.Patch(t.Context(), "test", nil, []byte(`{}`)) },
1094+
func() (*http.Response, error) { return httpSvc.Delete(t.Context(), "test", []byte(`{}`)) },
1095+
}
1096+
1097+
for _, method := range methods {
1098+
wg.Add(1)
1099+
1100+
go func(fn func() (*http.Response, error)) {
1101+
defer wg.Done()
1102+
1103+
resp, err := fn()
1104+
if err == nil && resp != nil {
1105+
_ = resp.Body.Close()
1106+
}
1107+
}(method)
1108+
}
1109+
1110+
wg.Wait()
1111+
1112+
totalTime := time.Since(startTime)
1113+
1114+
// All 5 methods should complete in ~1s (parallel)
1115+
assert.Less(t, totalTime, 2*time.Second, "Different HTTP methods should execute in parallel")
1116+
}

0 commit comments

Comments
 (0)