2525import java .sql .SQLException ;
2626import java .util .ArrayDeque ;
2727import java .util .Deque ;
28+ import java .util .HashMap ;
29+ import java .util .Map ;
2830import java .util .Properties ;
2931import org .apache .arrow .driver .jdbc .authentication .Authentication ;
3032import org .apache .arrow .driver .jdbc .authentication .TokenAuthentication ;
3335import org .apache .arrow .flight .CallHeaders ;
3436import org .apache .arrow .flight .CallInfo ;
3537import org .apache .arrow .flight .CallStatus ;
38+ import org .apache .arrow .flight .FlightMethod ;
3639import org .apache .arrow .flight .FlightServer ;
3740import org .apache .arrow .flight .FlightServerMiddleware ;
3841import org .apache .arrow .flight .Location ;
@@ -67,7 +70,8 @@ public class FlightServerTestExtension
6770 private final CertKeyPair certKeyPair ;
6871 private final File mTlsCACert ;
6972
70- private final MiddlewareCookie .Factory middlewareCookieFactory = new MiddlewareCookie .Factory ();
73+ private final InterceptorMiddleware .Factory interceptorFactory =
74+ new InterceptorMiddleware .Factory ();
7175
7276 private FlightServerTestExtension (
7377 final Properties properties ,
@@ -130,8 +134,8 @@ private void setUseEncryption(boolean useEncryption) {
130134 properties .put ("useEncryption" , useEncryption );
131135 }
132136
133- public MiddlewareCookie .Factory getMiddlewareCookieFactory () {
134- return middlewareCookieFactory ;
137+ public InterceptorMiddleware .Factory getInterceptorFactory () {
138+ return interceptorFactory ;
135139 }
136140
137141 @ FunctionalInterface
@@ -143,7 +147,7 @@ private FlightServer initiateServer(Location location) throws IOException {
143147 FlightServer .Builder builder =
144148 FlightServer .builder (allocator , location , producer )
145149 .headerAuthenticator (authentication .authenticate ())
146- .middleware (FlightServerMiddleware .Key .of ("KEY" ), middlewareCookieFactory );
150+ .middleware (FlightServerMiddleware .Key .of ("KEY" ), interceptorFactory );
147151 if (certKeyPair != null ) {
148152 builder .useTls (certKeyPair .cert , certKeyPair .key );
149153 }
@@ -301,11 +305,11 @@ public FlightServerTestExtension build() {
301305 * A middleware to handle with the cookies in the server. It is used to test if cookies are being
302306 * sent properly.
303307 */
304- static class MiddlewareCookie implements FlightServerMiddleware {
308+ static class InterceptorMiddleware implements FlightServerMiddleware {
305309
306310 private final Factory factory ;
307311
308- public MiddlewareCookie (Factory factory ) {
312+ public InterceptorMiddleware (Factory factory ) {
309313 this .factory = factory ;
310314 }
311315
@@ -323,22 +327,33 @@ public void onCallCompleted(CallStatus callStatus) {}
323327 public void onCallErrored (Throwable throwable ) {}
324328
325329 /** A factory for the MiddlewareCookie. */
326- static class Factory implements FlightServerMiddleware .Factory <MiddlewareCookie > {
330+ static class Factory implements FlightServerMiddleware .Factory <InterceptorMiddleware > {
327331
332+ private final Map <FlightMethod , CallHeaders > receivedCallHeaders = new HashMap <>();
328333 private boolean receivedCookieHeader = false ;
329334 private String cookie ;
330335
331336 @ Override
332- public MiddlewareCookie onCallStarted (
337+ public InterceptorMiddleware onCallStarted (
333338 CallInfo callInfo , CallHeaders callHeaders , RequestContext requestContext ) {
334339 cookie = callHeaders .get ("Cookie" );
335340 receivedCookieHeader = null != cookie ;
336- return new MiddlewareCookie (this );
341+
342+ receivedCallHeaders .put (callInfo .method (), callHeaders );
343+ return new InterceptorMiddleware (this );
337344 }
338345
339346 public String getCookie () {
340347 return cookie ;
341348 }
349+
350+ public String getHeader (FlightMethod method , String key ) {
351+ CallHeaders headers = receivedCallHeaders .get (method );
352+ if (headers == null ) {
353+ return null ;
354+ }
355+ return headers .get (key );
356+ }
342357 }
343358 }
344359}
0 commit comments