Skip to content

Commit 878a8cf

Browse files
Fixed tests for CircuitBreakerInterceptor
1 parent 6a058f7 commit 878a8cf

File tree

3 files changed

+137
-134
lines changed

3 files changed

+137
-134
lines changed

grpc-circuitbreaker-utils/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ dependencies {
2121

2222
testImplementation("org.junit.jupiter:junit-jupiter:5.8.2")
2323
testImplementation("org.mockito:mockito-core:5.8.0")
24+
testImplementation("org.mockito:mockito-junit-jupiter:5.8.0")
2425
}
2526

2627
tasks.test {

grpc-circuitbreaker-utils/src/main/java/org/hypertrace/circuitbreaker/grpcutils/resilience/ResilienceCircuitBreakerInterceptor.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,15 @@ public void start(Listener<RespT> responseListener, Metadata headers) {
8888
public void sendMessage(ReqT message) {
8989
CircuitBreakerConfiguration<ReqT> config =
9090
(CircuitBreakerConfiguration<ReqT>) circuitBreakerConfiguration;
91-
if (config.getRequestClass() == null
92-
|| (!message.getClass().equals(config.getRequestClass()))) {
93-
log.warn("Invalid config for message type: {}", message.getClass());
94-
super.sendMessage(message);
95-
}
96-
if (config.getKeyFunction() != null) {
97-
circuitBreakerKey = config.getKeyFunction().apply(RequestContext.CURRENT.get(), message);
98-
} else {
91+
if (config.getRequestClass() == null || config.getKeyFunction() == null) {
92+
log.debug("Circuit breaker will apply to all requests as config is not set");
9993
circuitBreakerKey = "default";
94+
} else {
95+
if (!message.getClass().equals(config.getRequestClass())) {
96+
log.warn("Invalid config for message type: {}", message.getClass());
97+
super.sendMessage(message);
98+
}
99+
circuitBreakerKey = config.getKeyFunction().apply(RequestContext.CURRENT.get(), message);
100100
}
101101
circuitBreaker = resilienceCircuitBreakerProvider.getCircuitBreaker(circuitBreakerKey);
102102
if (!circuitBreaker.tryAcquirePermission()) {
Original file line numberDiff line numberDiff line change
@@ -1,168 +1,170 @@
11
package org.hypertrace.circuitbreaker.grpcutils.resilience;
22

3-
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
4-
import static org.junit.jupiter.api.Assertions.assertNotNull;
53
import static org.junit.jupiter.api.Assertions.assertThrows;
64
import static org.mockito.ArgumentMatchers.any;
7-
import static org.mockito.ArgumentMatchers.eq;
8-
import static org.mockito.Mockito.doNothing;
9-
import static org.mockito.Mockito.mock;
10-
import static org.mockito.Mockito.spy;
11-
import static org.mockito.Mockito.verify;
12-
import static org.mockito.Mockito.when;
13-
14-
import com.typesafe.config.Config;
15-
import com.typesafe.config.ConfigFactory;
5+
import static org.mockito.ArgumentMatchers.anyString;
6+
import static org.mockito.Mockito.*;
7+
168
import io.github.resilience4j.circuitbreaker.CircuitBreaker;
179
import io.github.resilience4j.circuitbreaker.CircuitBreakerRegistry;
18-
import io.grpc.CallOptions;
19-
import io.grpc.Channel;
20-
import io.grpc.ClientCall;
21-
import io.grpc.ForwardingClientCall;
22-
import io.grpc.Metadata;
23-
import io.grpc.MethodDescriptor;
24-
import io.grpc.StatusRuntimeException;
10+
import io.grpc.*;
2511
import java.time.Clock;
26-
import org.hypertrace.circuitbreaker.grpcutils.CircuitBreakerConfigParser;
12+
import java.time.Instant;
13+
import java.time.ZoneOffset;
14+
import java.util.concurrent.TimeUnit;
2715
import org.hypertrace.circuitbreaker.grpcutils.CircuitBreakerConfiguration;
28-
import org.hypertrace.core.grpcutils.context.RequestContext;
29-
import org.junit.jupiter.api.Disabled;
16+
import org.junit.jupiter.api.BeforeEach;
3017
import org.junit.jupiter.api.Test;
31-
import org.mockito.Mockito;
18+
import org.junit.jupiter.api.extension.ExtendWith;
19+
import org.mockito.*;
20+
import org.mockito.junit.jupiter.MockitoExtension;
3221

22+
@ExtendWith(MockitoExtension.class)
3323
class ResilienceCircuitBreakerInterceptorTest {
3424

35-
private final Config config =
36-
ConfigFactory.parseString(
37-
"default {\n"
38-
+ " failureRateThreshold=50.0\n"
39-
+ " slowCallRateThreshold=100.0\n"
40-
+ " slowCallDurationThreshold=5s\n"
41-
+ " slidingWindowSize=10\n"
42-
+ " waitDurationInOpenState=1m\n"
43-
+ " minimumNumberOfCalls=5\n"
44-
+ " permittedNumberOfCallsInHalfOpenState=3\n"
45-
+ " slidingWindowType=COUNT_BASED\n"
46-
+ "}");
47-
private final Clock clock = Clock.systemUTC();
48-
private final CircuitBreakerRegistry mockRegistry = Mockito.mock(CircuitBreakerRegistry.class);
49-
private final CircuitBreaker mockCircuitBreaker = Mockito.mock(CircuitBreaker.class);
50-
private final Channel mockChannel = Mockito.mock(Channel.class);
51-
private final ClientCall.Listener<Object> mockListener = mock(ClientCall.Listener.class);
52-
private final ResilienceCircuitBreakerProvider mockCircuitBreakerProvider =
53-
Mockito.mock(ResilienceCircuitBreakerProvider.class);
25+
@Mock private Channel mockChannel;
26+
@Mock private ClientCall<Object, Object> mockClientCall;
27+
@Mock private CircuitBreaker mockCircuitBreaker;
28+
@Mock private Metadata mockMetadata;
29+
@Mock private ClientCall.Listener<Object> mockListener;
30+
@Mock private ResilienceCircuitBreakerProvider mockCircuitBreakerProvider;
31+
@Mock private CircuitBreakerConfiguration<Object> mockCircuitBreakerConfig;
32+
@Mock private CircuitBreakerRegistry mockCircuitBreakerRegistry;
33+
34+
@Mock private Clock fixedClock;
35+
36+
@BeforeEach
37+
void setUp() {
38+
MockitoAnnotations.openMocks(this);
39+
40+
fixedClock = Clock.fixed(Instant.now(), ZoneOffset.UTC);
41+
when(mockChannel.newCall(any(), any())).thenReturn(mockClientCall);
42+
when(mockCircuitBreakerProvider.getCircuitBreaker(anyString())).thenReturn(mockCircuitBreaker);
43+
}
5444

5545
@Test
56-
void testCircuitBreakerEnabled_InterceptsCall() {
57-
MethodDescriptor<Object, Object> methodDescriptor = mock(MethodDescriptor.class);
58-
when(mockCircuitBreakerProvider.getCircuitBreaker("test-key")).thenReturn(mockCircuitBreaker);
59-
CircuitBreakerConfiguration<?> circuitBreakerConfiguration =
60-
CircuitBreakerConfigParser.parseConfig(config)
61-
.enabled(true)
62-
.keyFunction(
63-
(requestContext, request) -> {
64-
GetGithubIntegrationsRequest getGithubIntegrationsRequest =
65-
(GetGithubIntegrationsRequest) request;
66-
return requestContext.getTenantId() + "-" + getGithubIntegrationsRequest.getUrl();
67-
})
68-
.build();
46+
void testSendMessage_CallsSuperSendMessage_Success() {
47+
doNothing().when(mockClientCall).sendMessage(any());
48+
when(mockCircuitBreaker.tryAcquirePermission()).thenReturn(true);
49+
6950
ResilienceCircuitBreakerInterceptor interceptor =
7051
new ResilienceCircuitBreakerInterceptor(
71-
clock, mockRegistry, mockCircuitBreakerProvider, circuitBreakerConfiguration);
52+
fixedClock,
53+
mockCircuitBreakerRegistry,
54+
mockCircuitBreakerProvider,
55+
mockCircuitBreakerConfig);
7256

73-
CallOptions callOptions = Mockito.mock(CallOptions.class);
7457
ClientCall<Object, Object> interceptedCall =
75-
spy(interceptor.interceptCall(methodDescriptor, callOptions, mockChannel));
76-
doNothing().when(interceptedCall).start(any(), any());
77-
assertNotNull(interceptedCall);
78-
assertDoesNotThrow(() -> interceptedCall.start(mockListener, new Metadata()));
79-
verify(interceptedCall).start(eq(mockListener), any(Metadata.class));
58+
interceptor.createInterceptedCall(
59+
mock(MethodDescriptor.class), CallOptions.DEFAULT, mockChannel);
60+
61+
interceptedCall.start(mockListener, mockMetadata);
62+
interceptedCall.sendMessage(new Object());
63+
64+
verify(mockClientCall).sendMessage(any());
8065
}
8166

8267
@Test
83-
void testCircuitBreakerRejectsRequest() {
84-
MethodDescriptor<Object, Object> methodDescriptor = mock(MethodDescriptor.class);
85-
CallOptions callOptions = Mockito.mock(CallOptions.class);
68+
void testSendMessage_CircuitBreakerRejectsRequest() {
8669
when(mockCircuitBreaker.tryAcquirePermission()).thenReturn(false);
8770
when(mockCircuitBreaker.getState()).thenReturn(CircuitBreaker.State.OPEN);
88-
when(mockCircuitBreakerProvider.getCircuitBreaker("tenant1-http://localhost:9000"))
89-
.thenReturn(mockCircuitBreaker);
90-
CircuitBreakerConfiguration<GetGithubIntegrationsRequest> circuitBreakerConfiguration =
91-
CircuitBreakerConfigParser.<GetGithubIntegrationsRequest>parseConfig(config)
92-
.enabled(true)
93-
.requestClass(GetGithubIntegrationsRequest.class)
94-
.keyFunction(
95-
(requestContext, request) -> {
96-
return requestContext.getTenantId().get() + "-" + request.getUrl();
97-
})
98-
.build();
9971
ResilienceCircuitBreakerInterceptor interceptor =
10072
new ResilienceCircuitBreakerInterceptor(
101-
clock, mockRegistry, mockCircuitBreakerProvider, circuitBreakerConfiguration);
73+
fixedClock,
74+
mockCircuitBreakerRegistry,
75+
mockCircuitBreakerProvider,
76+
mockCircuitBreakerConfig);
77+
78+
ClientCall<Object, Object> interceptedCall =
79+
interceptor.createInterceptedCall(
80+
mock(MethodDescriptor.class), CallOptions.DEFAULT, mockChannel);
81+
82+
interceptedCall.start(mockListener, mockMetadata);
83+
84+
assertThrows(
85+
StatusRuntimeException.class,
86+
() -> interceptedCall.sendMessage(new Object()),
87+
"Circuit Breaker should reject request");
88+
89+
verify(mockClientCall, never()).sendMessage(any());
90+
}
91+
92+
@Test
93+
void testSendMessage_CircuitBreakerInHalfOpenState() {
94+
when(mockCircuitBreaker.tryAcquirePermission()).thenReturn(false);
95+
when(mockCircuitBreaker.getState()).thenReturn(CircuitBreaker.State.HALF_OPEN);
96+
ResilienceCircuitBreakerInterceptor interceptor =
97+
new ResilienceCircuitBreakerInterceptor(
98+
fixedClock,
99+
mockCircuitBreakerRegistry,
100+
mockCircuitBreakerProvider,
101+
mockCircuitBreakerConfig);
102102

103103
ClientCall<Object, Object> interceptedCall =
104-
interceptor.interceptCall(methodDescriptor, callOptions, mockChannel);
104+
interceptor.createInterceptedCall(
105+
mock(MethodDescriptor.class), CallOptions.DEFAULT, mockChannel);
106+
107+
interceptedCall.start(mockListener, mockMetadata);
108+
105109
assertThrows(
106110
StatusRuntimeException.class,
107-
() -> {
108-
RequestContext.forTenantId("tenant1")
109-
.call(
110-
() -> {
111-
interceptedCall.sendMessage(
112-
new GetGithubIntegrationsRequest("http://localhost:9000"));
113-
return null;
114-
});
115-
});
111+
() -> interceptedCall.sendMessage(new Object()),
112+
"Circuit Breaker should reject requests when in HALF-OPEN state");
113+
114+
verify(mockClientCall, never()).sendMessage(any());
116115
}
117116

118117
@Test
119-
@Disabled
120-
void testCircuitBreakerSuccess() {
121-
MethodDescriptor<Object, Object> methodDescriptor = mock(MethodDescriptor.class);
122-
CallOptions callOptions = Mockito.mock(CallOptions.class);
118+
void testWrapListenerWithCircuitBreaker_Success() {
123119
when(mockCircuitBreaker.tryAcquirePermission()).thenReturn(true);
124-
when(mockCircuitBreaker.getState()).thenReturn(CircuitBreaker.State.CLOSED);
125-
when(mockCircuitBreakerProvider.getCircuitBreaker("test-key")).thenReturn(mockCircuitBreaker);
126-
CircuitBreakerConfiguration<GetGithubIntegrationsRequest> circuitBreakerConfiguration =
127-
CircuitBreakerConfigParser.<GetGithubIntegrationsRequest>parseConfig(config)
128-
.enabled(true)
129-
.requestClass(GetGithubIntegrationsRequest.class)
130-
.keyFunction(
131-
(requestContext, request) -> {
132-
return requestContext.getTenantId().get() + "-" + request.getUrl();
133-
})
134-
.build();
135120
ResilienceCircuitBreakerInterceptor interceptor =
136-
spy(
137-
new ResilienceCircuitBreakerInterceptor(
138-
clock, mockRegistry, mockCircuitBreakerProvider, circuitBreakerConfiguration));
121+
new ResilienceCircuitBreakerInterceptor(
122+
fixedClock,
123+
mockCircuitBreakerRegistry,
124+
mockCircuitBreakerProvider,
125+
mockCircuitBreakerConfig);
139126

140127
ClientCall<Object, Object> interceptedCall =
141-
interceptor.createInterceptedCall(methodDescriptor, callOptions, mockChannel);
142-
ClientCall<Object, Object> spyCall = spy(interceptedCall);
143-
Mockito.doNothing().when((ForwardingClientCall) interceptedCall).sendMessage(Mockito.any());
144-
// Act
145-
RequestContext.forTenantId("tenant1")
146-
.call(
147-
() -> {
148-
spyCall.sendMessage(new Object());
149-
return null;
150-
});
151-
152-
// Assert
153-
verify(spyCall).sendMessage(any());
154-
verify(mockCircuitBreaker).tryAcquirePermission();
128+
interceptor.createInterceptedCall(
129+
mock(MethodDescriptor.class), CallOptions.DEFAULT, mockChannel);
130+
131+
interceptedCall.start(mockListener, mockMetadata);
132+
interceptedCall.sendMessage(new Object());
133+
134+
// Trigger `onClose` directly to mimic gRPC's flow
135+
ArgumentCaptor<ForwardingClientCallListener<Object>> listenerCaptor =
136+
ArgumentCaptor.forClass(ForwardingClientCallListener.class);
137+
verify(mockClientCall).start(listenerCaptor.capture(), any());
138+
listenerCaptor.getValue().onClose(Status.OK, mockMetadata);
139+
140+
verify(mockClientCall).sendMessage(any());
141+
verify(mockCircuitBreaker).onSuccess(anyLong(), eq(TimeUnit.NANOSECONDS));
155142
}
156143

157-
private static class GetGithubIntegrationsRequest {
158-
private final String url;
144+
@Test
145+
void testWrapListenerWithCircuitBreaker_Failure() {
146+
when(mockCircuitBreaker.tryAcquirePermission()).thenReturn(true);
147+
ResilienceCircuitBreakerInterceptor interceptor =
148+
new ResilienceCircuitBreakerInterceptor(
149+
fixedClock,
150+
mockCircuitBreakerRegistry,
151+
mockCircuitBreakerProvider,
152+
mockCircuitBreakerConfig);
153+
154+
ClientCall<Object, Object> interceptedCall =
155+
interceptor.createInterceptedCall(
156+
mock(MethodDescriptor.class), CallOptions.DEFAULT, mockChannel);
157+
158+
interceptedCall.start(mockListener, mockMetadata);
159+
interceptedCall.sendMessage(new Object());
159160

160-
public GetGithubIntegrationsRequest(String url) {
161-
this.url = url;
162-
}
161+
// Trigger `onClose` directly to mimic gRPC's flow
162+
ArgumentCaptor<ForwardingClientCallListener<Object>> listenerCaptor =
163+
ArgumentCaptor.forClass(ForwardingClientCallListener.class);
164+
verify(mockClientCall).start(listenerCaptor.capture(), any());
165+
listenerCaptor.getValue().onClose(Status.UNKNOWN, mockMetadata);
163166

164-
public String getUrl() {
165-
return url;
166-
}
167+
verify(mockClientCall).sendMessage(any());
168+
verify(mockCircuitBreaker).onError(anyLong(), eq(TimeUnit.NANOSECONDS), any());
167169
}
168170
}

0 commit comments

Comments
 (0)