|
3 | 3 | import dev.openfeature.contrib.providers.flagd.FlagdOptions; |
4 | 4 | import dev.openfeature.contrib.providers.flagd.resolver.common.nameresolvers.EnvoyResolverProvider; |
5 | 5 | import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; |
| 6 | +import io.grpc.CallOptions; |
| 7 | +import io.grpc.Channel; |
| 8 | +import io.grpc.ClientCall; |
| 9 | +import io.grpc.ClientInterceptor; |
| 10 | +import io.grpc.ForwardingClientCall; |
6 | 11 | import io.grpc.ManagedChannel; |
| 12 | +import io.grpc.Metadata; |
| 13 | +import io.grpc.MethodDescriptor; |
7 | 14 | import io.grpc.NameResolverRegistry; |
8 | 15 | import io.grpc.Status.Code; |
9 | 16 | import io.grpc.netty.GrpcSslContexts; |
@@ -94,14 +101,19 @@ public static ManagedChannel nettyChannel(final FlagdOptions options) { |
94 | 101 | if (!Epoll.isAvailable()) { |
95 | 102 | throw new IllegalStateException("unix socket cannot be used", Epoll.unavailabilityCause()); |
96 | 103 | } |
97 | | - return NettyChannelBuilder.forAddress(new DomainSocketAddress(options.getSocketPath())) |
| 104 | + var channelBuilder = NettyChannelBuilder.forAddress(new DomainSocketAddress(options.getSocketPath())) |
98 | 105 | .keepAliveTime(keepAliveMs, TimeUnit.MILLISECONDS) |
99 | 106 | .eventLoopGroup(new MultiThreadIoEventLoopGroup(EpollIoHandler.newFactory())) |
100 | 107 | .channelType(EpollDomainSocketChannel.class) |
101 | 108 | .usePlaintext() |
102 | 109 | .defaultServiceConfig(buildRetryPolicy(options)) |
103 | | - .enableRetry() |
104 | | - .build(); |
| 110 | + .enableRetry(); |
| 111 | + |
| 112 | + // add header-based selector interceptor if selector is provided |
| 113 | + if (options.getSelector() != null) { |
| 114 | + channelBuilder.intercept(createSelectorInterceptor(options.getSelector())); |
| 115 | + } |
| 116 | + return channelBuilder.build(); |
105 | 117 | } |
106 | 118 |
|
107 | 119 | // build a TCP socket |
@@ -160,6 +172,30 @@ public static ManagedChannel nettyChannel(final FlagdOptions options) { |
160 | 172 | } |
161 | 173 | } |
162 | 174 |
|
| 175 | + /** |
| 176 | + * Creates a ClientInterceptor that adds the flagd-selector header to gRPC requests. |
| 177 | + * This is the preferred approach for passing selectors as per flagd issue #1814. |
| 178 | + * |
| 179 | + * @param selector the selector value to pass in the header |
| 180 | + * @return a ClientInterceptor that adds the flagd-selector header |
| 181 | + */ |
| 182 | + private static ClientInterceptor createSelectorInterceptor(String selector) { |
| 183 | + return new ClientInterceptor() { |
| 184 | + @Override |
| 185 | + public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( |
| 186 | + MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) { |
| 187 | + return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>( |
| 188 | + next.newCall(method, callOptions)) { |
| 189 | + @Override |
| 190 | + public void start(Listener<RespT> responseListener, Metadata headers) { |
| 191 | + headers.put(Metadata.Key.of("flagd-selector", Metadata.ASCII_STRING_MARSHALLER), selector); |
| 192 | + super.start(responseListener, headers); |
| 193 | + } |
| 194 | + }; |
| 195 | + } |
| 196 | + }; |
| 197 | + } |
| 198 | + |
163 | 199 | private static boolean isValidTargetUri(String targetUri) { |
164 | 200 | if (targetUri == null) { |
165 | 201 | return false; |
|
0 commit comments