Skip to content

Commit 875dcce

Browse files
committed
feat(flagd): Add authorityOverride and clientInterceptors
Signed-off-by: Maks Osowski <[email protected]>
1 parent ec12ef7 commit 875dcce

File tree

5 files changed

+155
-0
lines changed

5 files changed

+155
-0
lines changed

providers/flagd/README.md

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,50 @@ FlagdProvider flagdProvider = new FlagdProvider(
180180
> There's a [vulnerability](https://security.snyk.io/vuln/SNYK-JAVA-IONETTY-1042268) in [netty](https://github.com/netty/netty), a transitive dependency of the underlying gRPC libraries used in the flagd-provider that fails to correctly validate certificates.
181181
> This will be addressed in netty v5.
182182
183+
### Configuring gRPC credentials and headers
184+
185+
The `clientInterceptors` and `authorityOverride` are meant for connection of the in-process resolver to a Sync API implementation on a host/port, that might require special credentials or headers.
186+
187+
```java
188+
private static ClientInterceptor createHeaderInterceptor() {
189+
return new ClientInterceptor() {
190+
@Override
191+
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
192+
return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(next.newCall(method, callOptions)) {
193+
@Override
194+
public void start(Listener<RespT> responseListener, Metadata headers) {
195+
headers.put(Metadata.Key.of("custom-header", Metadata.ASCII_STRING_MARSHALLER), "header-value");
196+
super.start(responseListener, headers);
197+
}
198+
};
199+
}
200+
};
201+
}
202+
203+
private static ClientInterceptor createCallCrednetialsInterceptor(CallCredentials callCredentials) throws IOException {
204+
return new ClientInterceptor() {
205+
@Override
206+
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
207+
return next.newCall(method, callOptions.withCallCredentials(callCredentials));
208+
}
209+
};
210+
}
211+
212+
List<ClientInterceptor> clientInterceptors = new ArrayList<ClientInterceptor>(2);
213+
clientInterceptors.add(createHeaderInterceptor());
214+
CallCredentials myCallCredentals = ...;
215+
clientInterceptors.add(createCallCrednetialsInterceptor(myCallCredentials));
216+
217+
FlagdProvider flagdProvider = new FlagdProvider(
218+
FlagdOptions.builder()
219+
.host("example.com/flagdSyncApi")
220+
.port(443)
221+
.tls(true)
222+
.overrideAuthority("authority-host.sync.example.com")
223+
.clientInterceptors(clientInterceptors)
224+
.build());
225+
```
226+
183227
### Caching (RPC only)
184228

185229
> [!NOTE]

providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/FlagdOptions.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
import dev.openfeature.sdk.EvaluationContext;
88
import dev.openfeature.sdk.ImmutableContext;
99
import dev.openfeature.sdk.Structure;
10+
import io.grpc.ClientInterceptor;
1011
import io.opentelemetry.api.GlobalOpenTelemetry;
1112
import io.opentelemetry.api.OpenTelemetry;
13+
import java.util.List;
1214
import java.util.function.Function;
1315
import lombok.Builder;
1416
import lombok.Getter;
@@ -164,6 +166,18 @@ public class FlagdOptions {
164166
*/
165167
private OpenTelemetry openTelemetry;
166168

169+
/**
170+
* gRPC client interceptors to be used when creating a gRPC channel.
171+
*/
172+
@Builder.Default
173+
private List<ClientInterceptor> clientInterceptors = null;
174+
175+
/**
176+
* Authority header to be used when creating a gRPC channel.
177+
*/
178+
@Builder.Default
179+
private String authorityOverride = null;
180+
167181
/**
168182
* Builder overwrite in order to customize the "build" method.
169183
*

providers/flagd/src/main/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelBuilder.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,12 @@ public static ManagedChannel nettyChannel(final FlagdOptions options) {
6363
final NettyChannelBuilder builder =
6464
NettyChannelBuilder.forTarget(targetUri).keepAliveTime(keepAliveMs, TimeUnit.MILLISECONDS);
6565

66+
if (options.getAuthorityOverride() != null) {
67+
builder.overrideAuthority(options.getAuthorityOverride());
68+
}
69+
if (options.getClientInterceptors() != null) {
70+
builder.intercept(options.getClientInterceptors());
71+
}
6672
if (options.isTls()) {
6773
SslContextBuilder sslContext = GrpcSslContexts.forClient();
6874

providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/FlagdOptionsTest.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020

2121
import dev.openfeature.contrib.providers.flagd.resolver.process.storage.MockConnector;
2222
import dev.openfeature.contrib.providers.flagd.resolver.process.storage.connector.Connector;
23+
import io.grpc.ClientInterceptor;
2324
import io.opentelemetry.api.OpenTelemetry;
25+
import java.util.ArrayList;
26+
import java.util.List;
2427
import java.util.function.Function;
2528
import org.junit.jupiter.api.Nested;
2629
import org.junit.jupiter.api.Test;
@@ -46,12 +49,15 @@ void TestDefaults() {
4649
assertNull(builder.getOfflineFlagSourcePath());
4750
assertEquals(Resolver.RPC, builder.getResolverType());
4851
assertEquals(0, builder.getKeepAlive());
52+
assertNull(builder.getAuthorityOverride());
53+
assertNull(builder.getClientInterceptors());
4954
}
5055

5156
@Test
5257
void TestBuilderOptions() {
5358
OpenTelemetry openTelemetry = Mockito.mock(OpenTelemetry.class);
5459
Connector connector = new MockConnector(null);
60+
List<ClientInterceptor> clientInterceptors = new ArrayList<ClientInterceptor>();
5561

5662
FlagdOptions flagdOptions = FlagdOptions.builder()
5763
.host("https://hosted-flagd")
@@ -66,6 +72,8 @@ void TestBuilderOptions() {
6672
.resolverType(Resolver.IN_PROCESS)
6773
.targetUri("dns:///localhost:8016")
6874
.keepAlive(1000)
75+
.authorityOverride("test-authority.sync.example.com")
76+
.clientInterceptors(clientInterceptors)
6977
.build();
7078

7179
assertEquals("https://hosted-flagd", flagdOptions.getHost());
@@ -80,6 +88,8 @@ void TestBuilderOptions() {
8088
assertEquals(Resolver.IN_PROCESS, flagdOptions.getResolverType());
8189
assertEquals("dns:///localhost:8016", flagdOptions.getTargetUri());
8290
assertEquals(1000, flagdOptions.getKeepAlive());
91+
assertEquals("test-authority.sync.example.com", flagdOptions.getAuthorityOverride());
92+
assertEquals(clientInterceptors, flagdOptions.getClientInterceptors());
8393
}
8494

8595
@Test

providers/flagd/src/test/java/dev/openfeature/contrib/providers/flagd/resolver/common/ChannelBuilderTest.java

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import static org.assertj.core.api.Assertions.assertThat;
44
import static org.assertj.core.api.Assertions.assertThatThrownBy;
5+
import static org.mockito.ArgumentMatchers.anyList;
56
import static org.mockito.Mockito.any;
67
import static org.mockito.Mockito.anyLong;
78
import static org.mockito.Mockito.anyString;
@@ -11,6 +12,7 @@
1112
import static org.mockito.Mockito.when;
1213

1314
import dev.openfeature.contrib.providers.flagd.FlagdOptions;
15+
import io.grpc.ClientInterceptor;
1416
import io.grpc.ManagedChannel;
1517
import io.grpc.netty.GrpcSslContexts;
1618
import io.grpc.netty.NettyChannelBuilder;
@@ -20,6 +22,8 @@
2022
import io.netty.channel.unix.DomainSocketAddress;
2123
import io.netty.handler.ssl.SslContextBuilder;
2224
import java.io.File;
25+
import java.util.ArrayList;
26+
import java.util.List;
2327
import java.util.concurrent.TimeUnit;
2428
import javax.net.ssl.SSLKeyException;
2529
import org.junit.jupiter.api.Test;
@@ -113,6 +117,83 @@ void testNettyChannel_withTlsAndCert() {
113117
}
114118
}
115119

120+
@Test
121+
void testNettyChannel_withAuthorityOverride() {
122+
try (MockedStatic<NettyChannelBuilder> nettyMock = mockStatic(NettyChannelBuilder.class)) {
123+
// Mocks
124+
NettyChannelBuilder mockBuilder = mock(NettyChannelBuilder.class);
125+
ManagedChannel mockChannel = mock(ManagedChannel.class);
126+
nettyMock
127+
.when(() -> NettyChannelBuilder.forTarget("localhost:8080"))
128+
.thenReturn(mockBuilder);
129+
130+
when(mockBuilder.keepAliveTime(anyLong(), any(TimeUnit.class))).thenReturn(mockBuilder);
131+
when(mockBuilder.sslContext(any())).thenReturn(mockBuilder);
132+
when(mockBuilder.overrideAuthority(anyString())).thenReturn(mockBuilder);
133+
when(mockBuilder.build()).thenReturn(mockChannel);
134+
135+
// Input options
136+
FlagdOptions options = FlagdOptions.builder()
137+
.host("localhost")
138+
.port(8080)
139+
.keepAlive(5000)
140+
.tls(true)
141+
.authorityOverride("test-authority.sync.example.com")
142+
.build();
143+
144+
// Call method under test
145+
ManagedChannel channel = ChannelBuilder.nettyChannel(options);
146+
147+
// Assertions
148+
assertThat(channel).isEqualTo(mockChannel);
149+
nettyMock.verify(() -> NettyChannelBuilder.forTarget("localhost:8080"));
150+
verify(mockBuilder).keepAliveTime(5000, TimeUnit.MILLISECONDS);
151+
verify(mockBuilder).sslContext(any());
152+
verify(mockBuilder).overrideAuthority("test-authority.sync.example.com");
153+
verify(mockBuilder).build();
154+
}
155+
}
156+
157+
@Test
158+
void testNettyChannel_withClientInterceptors() {
159+
try (MockedStatic<NettyChannelBuilder> nettyMock = mockStatic(NettyChannelBuilder.class)) {
160+
// Mocks
161+
NettyChannelBuilder mockBuilder = mock(NettyChannelBuilder.class);
162+
ManagedChannel mockChannel = mock(ManagedChannel.class);
163+
nettyMock
164+
.when(() -> NettyChannelBuilder.forTarget("localhost:8080"))
165+
.thenReturn(mockBuilder);
166+
167+
when(mockBuilder.keepAliveTime(anyLong(), any(TimeUnit.class))).thenReturn(mockBuilder);
168+
when(mockBuilder.sslContext(any())).thenReturn(mockBuilder);
169+
when(mockBuilder.intercept(anyList())).thenReturn(mockBuilder);
170+
when(mockBuilder.build()).thenReturn(mockChannel);
171+
172+
List<ClientInterceptor> clientInterceptors = new ArrayList<ClientInterceptor>();
173+
clientInterceptors.add(mock(ClientInterceptor.class));
174+
175+
// Input options
176+
FlagdOptions options = FlagdOptions.builder()
177+
.host("localhost")
178+
.port(8080)
179+
.keepAlive(5000)
180+
.tls(true)
181+
.clientInterceptors(clientInterceptors)
182+
.build();
183+
184+
// Call method under test
185+
ManagedChannel channel = ChannelBuilder.nettyChannel(options);
186+
187+
// Assertions
188+
assertThat(channel).isEqualTo(mockChannel);
189+
nettyMock.verify(() -> NettyChannelBuilder.forTarget("localhost:8080"));
190+
verify(mockBuilder).keepAliveTime(5000, TimeUnit.MILLISECONDS);
191+
verify(mockBuilder).sslContext(any());
192+
verify(mockBuilder).intercept(clientInterceptors);
193+
verify(mockBuilder).build();
194+
}
195+
}
196+
116197
@ParameterizedTest
117198
@ValueSource(strings = {"/incorrect/{uri}/;)"})
118199
void testNettyChannel_withInvalidTargetUri(String uri) {

0 commit comments

Comments
 (0)