2424import java .util .HashSet ;
2525import java .util .List ;
2626import java .util .Map ;
27+ import java .util .Objects ;
2728import java .util .Set ;
29+ import java .util .concurrent .ConcurrentHashMap ;
30+ import java .util .concurrent .ConcurrentMap ;
31+ import java .util .function .Function ;
2832import java .util .stream .Collectors ;
2933import net .bytebuddy .ByteBuddy ;
34+ import net .bytebuddy .description .modifier .Visibility ;
3035import net .bytebuddy .dynamic .loading .ClassLoadingStrategy ;
36+ import net .bytebuddy .implementation .FieldAccessor ;
3137import net .bytebuddy .implementation .InvocationHandlerAdapter ;
3238import net .bytebuddy .matcher .ElementMatchers ;
3339import org .openqa .selenium .Alert ;
183189@ Beta
184190public class WebDriverDecorator <T extends WebDriver > {
185191
192+ protected static class Definition {
193+ private final Class <?> decoratedClass ;
194+ private final Class <?> originalClass ;
195+
196+ public Definition (Decorated <?> decorated ) {
197+ this .decoratedClass = decorated .getClass ();
198+ this .originalClass = decorated .getOriginal ().getClass ();
199+ }
200+
201+ @ Override
202+ public boolean equals (Object o ) {
203+ if (o == null || getClass () != o .getClass ()) return false ;
204+ Definition definition = (Definition ) o ;
205+ return (decoratedClass == definition .decoratedClass )
206+ && (originalClass == definition .originalClass );
207+ }
208+
209+ @ Override
210+ public int hashCode () {
211+ return Objects .hash (decoratedClass , originalClass );
212+ }
213+ }
214+
215+ public interface HasTarget <Z > {
216+ Decorated <Z > getTarget ();
217+
218+ void setTarget (Decorated <Z > target );
219+ }
220+
221+ protected static class ProxyFactory <T > {
222+ private final Class <? extends T > clazz ;
223+
224+ private ProxyFactory (Class <? extends T > clazz ) {
225+ this .clazz = clazz ;
226+ }
227+
228+ public T newInstance (Decorated <T > target ) {
229+ T instance ;
230+ try {
231+ instance = (T ) clazz .newInstance ();
232+ } catch (ReflectiveOperationException e ) {
233+ throw new AssertionError ("Unable to create new proxy" , e );
234+ }
235+
236+ // ensure we can later find the target to call
237+ //noinspection unchecked
238+ ((HasTarget <T >) instance ).setTarget (target );
239+
240+ return instance ;
241+ }
242+ }
243+
244+ private final ConcurrentMap <Definition , ProxyFactory <?>> cache ;
245+
186246 private final Class <T > targetWebDriverClass ;
187247
188248 private Decorated <T > decorated ;
@@ -194,6 +254,7 @@ public WebDriverDecorator() {
194254
195255 public WebDriverDecorator (Class <T > targetClass ) {
196256 this .targetWebDriverClass = targetClass ;
257+ this .cache = new ConcurrentHashMap <>();
197258 }
198259
199260 public final T decorate (T original ) {
@@ -295,18 +356,36 @@ private Object decorateResult(Object toDecorate) {
295356 return toDecorate ;
296357 }
297358
298- protected final <Z > Z createProxy (final Decorated <Z > decorated , Class <Z > clazz ) {
299- Set <Class <?>> decoratedInterfaces = extractInterfaces (decorated );
300- Set <Class <?>> originalInterfaces = extractInterfaces (decorated .getOriginal ());
301- Map <Class <?>, InvocationHandler > derivedInterfaces =
302- deriveAdditionalInterfaces (decorated .getOriginal ());
359+ protected final <Z > Z createProxy (final Decorated <Z > decorated , Class <? extends Z > clazz ) {
360+ @ SuppressWarnings ("unchecked" )
361+ ProxyFactory <Z > factory =
362+ (ProxyFactory <Z >)
363+ cache .computeIfAbsent (
364+ new Definition (decorated ), (key ) -> createProxyFactory (key , decorated , clazz ));
365+
366+ return factory .newInstance (decorated );
367+ }
368+
369+ protected final <Z > ProxyFactory <? extends Z > createProxyFactory (
370+ Definition definition , final Decorated <Z > sample , Class <? extends Z > clazz ) {
371+ Set <Class <?>> decoratedInterfaces = extractInterfaces (definition .decoratedClass );
372+ Set <Class <?>> originalInterfaces = extractInterfaces (definition .originalClass );
373+ // all samples with the same definition should have the same derivedInterfaces
374+ Map <Class <?>, Function <Z , InvocationHandler >> derivedInterfaces =
375+ deriveAdditionalInterfaces (sample .getOriginal ());
303376
304377 final InvocationHandler handler =
305378 (proxy , method , args ) -> {
379+ // Lookup the instance to call, to reuse the clazz and handler.
380+ @ SuppressWarnings ("unchecked" )
381+ Decorated <Z > instance = ((HasTarget <Z >) proxy ).getTarget ();
382+ if (instance == null ) {
383+ throw new AssertionError ("Failed to get instance to call" );
384+ }
306385 try {
307386 if (method .getDeclaringClass ().equals (Object .class )
308387 || decoratedInterfaces .contains (method .getDeclaringClass ())) {
309- return method .invoke (decorated , args );
388+ return method .invoke (instance , args );
310389 }
311390 // Check if the class in which the method resides, implements any one of the
312391 // interfaces that we extracted from the decorated class.
@@ -317,9 +396,9 @@ protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<Z> clazz)
317396 eachInterface .isAssignableFrom (method .getDeclaringClass ()));
318397
319398 if (isCompatible ) {
320- decorated .beforeCall (method , args );
321- Object result = decorated .call (method , args );
322- decorated .afterCall (method , result , args );
399+ instance .beforeCall (method , args );
400+ Object result = instance .call (method , args );
401+ instance .afterCall (method , result , args );
323402 return result ;
324403 }
325404
@@ -333,19 +412,24 @@ protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<Z> clazz)
333412 eachInterface .isAssignableFrom (method .getDeclaringClass ()));
334413
335414 if (isCompatible ) {
336- return derivedInterfaces .get (method .getDeclaringClass ()).invoke (proxy , method , args );
415+ return derivedInterfaces
416+ .get (method .getDeclaringClass ())
417+ .apply (instance .getOriginal ())
418+ .invoke (proxy , method , args );
337419 }
338420
339- return method .invoke (decorated .getOriginal (), args );
421+ return method .invoke (instance .getOriginal (), args );
340422 } catch (InvocationTargetException e ) {
341- return decorated .onError (method , e , args );
423+ return instance .onError (method , e , args );
342424 }
343425 };
344426
345427 Set <Class <?>> allInterfaces = new HashSet <>();
346428 allInterfaces .addAll (decoratedInterfaces );
347429 allInterfaces .addAll (originalInterfaces );
348430 allInterfaces .addAll (derivedInterfaces .keySet ());
431+ // ensure a decorated driver can get decorated again
432+ allInterfaces .remove (HasTarget .class );
349433 Class <?>[] allInterfacesArray = allInterfaces .toArray (new Class <?>[0 ]);
350434
351435 Class <? extends Z > proxy =
@@ -354,20 +438,15 @@ protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<Z> clazz)
354438 .implement (allInterfacesArray )
355439 .method (ElementMatchers .any ())
356440 .intercept (InvocationHandlerAdapter .of (handler ))
441+ .defineField ("target" , Decorated .class , Visibility .PRIVATE )
442+ .implement (HasTarget .class )
443+ .intercept (FieldAccessor .ofField ("target" ))
357444 .make ()
358445 .load (clazz .getClassLoader (), ClassLoadingStrategy .Default .WRAPPER )
359446 .getLoaded ()
360447 .asSubclass (clazz );
361448
362- try {
363- return proxy .newInstance ();
364- } catch (ReflectiveOperationException e ) {
365- throw new IllegalStateException ("Unable to create new proxy" , e );
366- }
367- }
368-
369- static Set <Class <?>> extractInterfaces (final Object object ) {
370- return extractInterfaces (object .getClass ());
449+ return new ProxyFactory <Z >(proxy );
371450 }
372451
373452 private static Set <Class <?>> extractInterfaces (final Class <?> clazz ) {
@@ -393,43 +472,46 @@ private static void extractInterfaces(final Set<Class<?>> collector, final Class
393472 extractInterfaces (collector , clazz .getSuperclass ());
394473 }
395474
396- private Map <Class <?>, InvocationHandler > deriveAdditionalInterfaces (Object object ) {
397- Map <Class <?>, InvocationHandler > handlers = new HashMap <>();
475+ private < Z > Map <Class <?>, Function < Z , InvocationHandler >> deriveAdditionalInterfaces (Z sample ) {
476+ Map <Class <?>, Function < Z , InvocationHandler > > handlers = new HashMap <>();
398477
399- if (object instanceof WebDriver && !(object instanceof WrapsDriver )) {
478+ if (sample instanceof WebDriver && !(sample instanceof WrapsDriver )) {
400479 handlers .put (
401480 WrapsDriver .class ,
402- (proxy , method , args ) -> {
403- if ("getWrappedDriver" .equals (method .getName ())) {
404- return object ;
405- }
406- throw new UnsupportedOperationException (method .getName ());
407- });
481+ (instance ) ->
482+ (proxy , method , args ) -> {
483+ if ("getWrappedDriver" .equals (method .getName ())) {
484+ return instance ;
485+ }
486+ throw new UnsupportedOperationException (method .getName ());
487+ });
408488 }
409489
410- if (object instanceof WebElement && !(object instanceof WrapsElement )) {
490+ if (sample instanceof WebElement && !(sample instanceof WrapsElement )) {
411491 handlers .put (
412492 WrapsElement .class ,
413- (proxy , method , args ) -> {
414- if ("getWrappedElement" .equals (method .getName ())) {
415- return object ;
416- }
417- throw new UnsupportedOperationException (method .getName ());
418- });
493+ (instance ) ->
494+ (proxy , method , args ) -> {
495+ if ("getWrappedElement" .equals (method .getName ())) {
496+ return instance ;
497+ }
498+ throw new UnsupportedOperationException (method .getName ());
499+ });
419500 }
420501
421502 try {
422- Method toJson = object .getClass ().getDeclaredMethod ("toJson" );
503+ Method toJson = sample .getClass ().getDeclaredMethod ("toJson" );
423504 toJson .setAccessible (true );
424505
425506 handlers .put (
426507 JsonSerializer .class ,
427- ((proxy , method , args ) -> {
428- if ("toJson" .equals (method .getName ())) {
429- return toJson .invoke (object );
430- }
431- throw new UnsupportedOperationException (method .getName ());
432- }));
508+ (instance ) ->
509+ ((proxy , method , args ) -> {
510+ if ("toJson" .equals (method .getName ())) {
511+ return toJson .invoke (instance );
512+ }
513+ throw new UnsupportedOperationException (method .getName ());
514+ }));
433515 } catch (NoSuchMethodException e ) {
434516 // Fine. Just fall through
435517 }
0 commit comments