2020import  java .lang .reflect .InvocationHandler ;
2121import  java .lang .reflect .InvocationTargetException ;
2222import  java .lang .reflect .Method ;
23+ import  java .util .Arrays ;
2324import  java .util .HashMap ;
2425import  java .util .HashSet ;
2526import  java .util .List ;
2627import  java .util .Map ;
2728import  java .util .Set ;
29+ import  java .util .concurrent .ConcurrentHashMap ;
30+ import  java .util .concurrent .ConcurrentMap ;
31+ import  java .util .function .Function ;
2832import  java .util .stream .Collectors ;
2933import  net .bytebuddy .ByteBuddy ;
34+ import  net .bytebuddy .description .modifier .Visibility ;
3035import  net .bytebuddy .dynamic .loading .ClassLoadingStrategy ;
36+ import  net .bytebuddy .implementation .FieldAccessor ;
3137import  net .bytebuddy .implementation .InvocationHandlerAdapter ;
3238import  net .bytebuddy .matcher .ElementMatchers ;
3339import  org .openqa .selenium .Alert ;
183189@ Beta 
184190public  class  WebDriverDecorator <T  extends  WebDriver > {
185191
192+   protected  static  class  Definition  {
193+     private  final  Class <?> decoratedClass ;
194+     private  final  Class <?> originalClass ;
195+ 
196+     public  Definition (Decorated <?> decorated ) {
197+       this .decoratedClass  = decorated .getClass ();
198+       this .originalClass  = decorated .getOriginal ().getClass ();
199+     }
200+ 
201+     @ Override 
202+     public  boolean  equals (Object  o ) {
203+       if  (o  == null  || getClass () != o .getClass ()) return  false ;
204+       Definition  definition  = (Definition ) o ;
205+       // intentionally an identity check, to ensure we get no false positive lookup due to an 
206+       // unknown implementation of decoratedClass.equals or originalClass.equals 
207+       return  (decoratedClass  == definition .decoratedClass )
208+           && (originalClass  == definition .originalClass );
209+     }
210+ 
211+     @ Override 
212+     public  int  hashCode () {
213+       return  Arrays .hashCode (
214+           new  int [] {
215+             System .identityHashCode (decoratedClass ), System .identityHashCode (originalClass )
216+           });
217+     }
218+   }
219+ 
220+   public  interface  HasTarget <Z > {
221+     Decorated <Z > getTarget ();
222+ 
223+     void  setTarget (Decorated <Z > target );
224+   }
225+ 
226+   protected  static  class  ProxyFactory <T > {
227+     private  final  Class <? extends  T > clazz ;
228+ 
229+     private  ProxyFactory (Class <? extends  T > clazz ) {
230+       this .clazz  = clazz ;
231+     }
232+ 
233+     public  T  newInstance (Decorated <T > target ) {
234+       T  instance ;
235+       try  {
236+         instance  = (T ) clazz .newInstance ();
237+       } catch  (ReflectiveOperationException  e ) {
238+         throw  new  AssertionError ("Unable to create new proxy" , e );
239+       }
240+ 
241+       // ensure we can later find the target to call 
242+       //noinspection unchecked 
243+       ((HasTarget <T >) instance ).setTarget (target );
244+ 
245+       return  instance ;
246+     }
247+   }
248+ 
249+   private  final  ConcurrentMap <Definition , ProxyFactory <?>> cache ;
250+ 
186251  private  final  Class <T > targetWebDriverClass ;
187252
188253  private  Decorated <T > decorated ;
@@ -194,6 +259,7 @@ public WebDriverDecorator() {
194259
195260  public  WebDriverDecorator (Class <T > targetClass ) {
196261    this .targetWebDriverClass  = targetClass ;
262+     this .cache  = new  ConcurrentHashMap <>();
197263  }
198264
199265  public  final  T  decorate (T  original ) {
@@ -295,18 +361,36 @@ private Object decorateResult(Object toDecorate) {
295361    return  toDecorate ;
296362  }
297363
298-   protected  final  <Z > Z  createProxy (final  Decorated <Z > decorated , Class <Z > clazz ) {
299-     Set <Class <?>> decoratedInterfaces  = extractInterfaces (decorated );
300-     Set <Class <?>> originalInterfaces  = extractInterfaces (decorated .getOriginal ());
301-     Map <Class <?>, InvocationHandler > derivedInterfaces  =
302-         deriveAdditionalInterfaces (decorated .getOriginal ());
364+   protected  final  <Z > Z  createProxy (final  Decorated <Z > decorated , Class <? extends  Z > clazz ) {
365+     @ SuppressWarnings ("unchecked" )
366+     ProxyFactory <Z > factory  =
367+         (ProxyFactory <Z >)
368+             cache .computeIfAbsent (
369+                 new  Definition (decorated ), (key ) -> createProxyFactory (key , decorated , clazz ));
370+ 
371+     return  factory .newInstance (decorated );
372+   }
373+ 
374+   protected  final  <Z > ProxyFactory <? extends  Z > createProxyFactory (
375+       Definition  definition , final  Decorated <Z > sample , Class <? extends  Z > clazz ) {
376+     Set <Class <?>> decoratedInterfaces  = extractInterfaces (definition .decoratedClass );
377+     Set <Class <?>> originalInterfaces  = extractInterfaces (definition .originalClass );
378+     // all samples with the same definition should have the same derivedInterfaces 
379+     Map <Class <?>, Function <Z , InvocationHandler >> derivedInterfaces  =
380+         deriveAdditionalInterfaces (sample .getOriginal ());
303381
304382    final  InvocationHandler  handler  =
305383        (proxy , method , args ) -> {
384+           // Lookup the instance to call, to reuse the clazz and handler. 
385+           @ SuppressWarnings ("unchecked" )
386+           Decorated <Z > instance  = ((HasTarget <Z >) proxy ).getTarget ();
387+           if  (instance  == null ) {
388+             throw  new  AssertionError ("Failed to get instance to call" );
389+           }
306390          try  {
307391            if  (method .getDeclaringClass ().equals (Object .class )
308392                || decoratedInterfaces .contains (method .getDeclaringClass ())) {
309-               return  method .invoke (decorated , args );
393+               return  method .invoke (instance , args );
310394            }
311395            // Check if the class in which the method resides, implements any one of the 
312396            // interfaces that we extracted from the decorated class. 
@@ -317,9 +401,9 @@ protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<Z> clazz)
317401                            eachInterface .isAssignableFrom (method .getDeclaringClass ()));
318402
319403            if  (isCompatible ) {
320-               decorated .beforeCall (method , args );
321-               Object  result  = decorated .call (method , args );
322-               decorated .afterCall (method , result , args );
404+               instance .beforeCall (method , args );
405+               Object  result  = instance .call (method , args );
406+               instance .afterCall (method , result , args );
323407              return  result ;
324408            }
325409
@@ -333,19 +417,24 @@ protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<Z> clazz)
333417                            eachInterface .isAssignableFrom (method .getDeclaringClass ()));
334418
335419            if  (isCompatible ) {
336-               return  derivedInterfaces .get (method .getDeclaringClass ()).invoke (proxy , method , args );
420+               return  derivedInterfaces 
421+                   .get (method .getDeclaringClass ())
422+                   .apply (instance .getOriginal ())
423+                   .invoke (proxy , method , args );
337424            }
338425
339-             return  method .invoke (decorated .getOriginal (), args );
426+             return  method .invoke (instance .getOriginal (), args );
340427          } catch  (InvocationTargetException  e ) {
341-             return  decorated .onError (method , e , args );
428+             return  instance .onError (method , e , args );
342429          }
343430        };
344431
345432    Set <Class <?>> allInterfaces  = new  HashSet <>();
346433    allInterfaces .addAll (decoratedInterfaces );
347434    allInterfaces .addAll (originalInterfaces );
348435    allInterfaces .addAll (derivedInterfaces .keySet ());
436+     // ensure a decorated driver can get decorated again 
437+     allInterfaces .remove (HasTarget .class );
349438    Class <?>[] allInterfacesArray  = allInterfaces .toArray (new  Class <?>[0 ]);
350439
351440    Class <? extends  Z > proxy  =
@@ -354,20 +443,15 @@ protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<Z> clazz)
354443            .implement (allInterfacesArray )
355444            .method (ElementMatchers .any ())
356445            .intercept (InvocationHandlerAdapter .of (handler ))
446+             .defineField ("target" , Decorated .class , Visibility .PRIVATE )
447+             .implement (HasTarget .class )
448+             .intercept (FieldAccessor .ofField ("target" ))
357449            .make ()
358450            .load (clazz .getClassLoader (), ClassLoadingStrategy .Default .WRAPPER )
359451            .getLoaded ()
360452            .asSubclass (clazz );
361453
362-     try  {
363-       return  proxy .newInstance ();
364-     } catch  (ReflectiveOperationException  e ) {
365-       throw  new  IllegalStateException ("Unable to create new proxy" , e );
366-     }
367-   }
368- 
369-   static  Set <Class <?>> extractInterfaces (final  Object  object ) {
370-     return  extractInterfaces (object .getClass ());
454+     return  new  ProxyFactory <Z >(proxy );
371455  }
372456
373457  private  static  Set <Class <?>> extractInterfaces (final  Class <?> clazz ) {
@@ -393,43 +477,46 @@ private static void extractInterfaces(final Set<Class<?>> collector, final Class
393477    extractInterfaces (collector , clazz .getSuperclass ());
394478  }
395479
396-   private  Map <Class <?>, InvocationHandler > deriveAdditionalInterfaces (Object   object ) {
397-     Map <Class <?>, InvocationHandler > handlers  = new  HashMap <>();
480+   private  < Z >  Map <Class <?>, Function < Z ,  InvocationHandler >>  deriveAdditionalInterfaces (Z   sample ) {
481+     Map <Class <?>, Function < Z ,  InvocationHandler > > handlers  = new  HashMap <>();
398482
399-     if  (object  instanceof  WebDriver  && !(object  instanceof  WrapsDriver )) {
483+     if  (sample  instanceof  WebDriver  && !(sample  instanceof  WrapsDriver )) {
400484      handlers .put (
401485          WrapsDriver .class ,
402-           (proxy , method , args ) -> {
403-             if  ("getWrappedDriver" .equals (method .getName ())) {
404-               return  object ;
405-             }
406-             throw  new  UnsupportedOperationException (method .getName ());
407-           });
486+           (instance ) ->
487+               (proxy , method , args ) -> {
488+                 if  ("getWrappedDriver" .equals (method .getName ())) {
489+                   return  instance ;
490+                 }
491+                 throw  new  UnsupportedOperationException (method .getName ());
492+               });
408493    }
409494
410-     if  (object  instanceof  WebElement  && !(object  instanceof  WrapsElement )) {
495+     if  (sample  instanceof  WebElement  && !(sample  instanceof  WrapsElement )) {
411496      handlers .put (
412497          WrapsElement .class ,
413-           (proxy , method , args ) -> {
414-             if  ("getWrappedElement" .equals (method .getName ())) {
415-               return  object ;
416-             }
417-             throw  new  UnsupportedOperationException (method .getName ());
418-           });
498+           (instance ) ->
499+               (proxy , method , args ) -> {
500+                 if  ("getWrappedElement" .equals (method .getName ())) {
501+                   return  instance ;
502+                 }
503+                 throw  new  UnsupportedOperationException (method .getName ());
504+               });
419505    }
420506
421507    try  {
422-       Method  toJson  = object .getClass ().getDeclaredMethod ("toJson" );
508+       Method  toJson  = sample .getClass ().getDeclaredMethod ("toJson" );
423509      toJson .setAccessible (true );
424510
425511      handlers .put (
426512          JsonSerializer .class ,
427-           ((proxy , method , args ) -> {
428-             if  ("toJson" .equals (method .getName ())) {
429-               return  toJson .invoke (object );
430-             }
431-             throw  new  UnsupportedOperationException (method .getName ());
432-           }));
513+           (instance ) ->
514+               ((proxy , method , args ) -> {
515+                 if  ("toJson" .equals (method .getName ())) {
516+                   return  toJson .invoke (instance );
517+                 }
518+                 throw  new  UnsupportedOperationException (method .getName ());
519+               }));
433520    } catch  (NoSuchMethodException  e ) {
434521      // Fine. Just fall through 
435522    }
0 commit comments