Skip to content

Commit 35a12b1

Browse files
committed
[java] reuse the classes created by the WebDriverDecorator #14789
1 parent 4b7d174 commit 35a12b1

File tree

2 files changed

+186
-41
lines changed

2 files changed

+186
-41
lines changed

java/src/org/openqa/selenium/support/decorators/WebDriverDecorator.java

Lines changed: 146 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,23 @@
1717

1818
package org.openqa.selenium.support.decorators;
1919

20+
import java.lang.reflect.Field;
2021
import java.lang.reflect.InvocationHandler;
2122
import java.lang.reflect.InvocationTargetException;
2223
import java.lang.reflect.Method;
2324
import java.util.HashMap;
2425
import java.util.HashSet;
2526
import java.util.List;
2627
import java.util.Map;
28+
import java.util.Objects;
2729
import java.util.Set;
30+
import java.util.concurrent.CompletableFuture;
31+
import java.util.concurrent.ConcurrentHashMap;
32+
import java.util.concurrent.ConcurrentMap;
33+
import java.util.function.Function;
2834
import java.util.stream.Collectors;
2935
import net.bytebuddy.ByteBuddy;
36+
import net.bytebuddy.description.modifier.Visibility;
3037
import net.bytebuddy.dynamic.loading.ClassLoadingStrategy;
3138
import net.bytebuddy.implementation.InvocationHandlerAdapter;
3239
import net.bytebuddy.matcher.ElementMatchers;
@@ -183,6 +190,59 @@
183190
@Beta
184191
public class WebDriverDecorator<T extends WebDriver> {
185192

193+
protected static class Definition {
194+
private final Class<?> decoratedClass;
195+
private final Class<?> originalClass;
196+
197+
public Definition(Decorated<?> decorated) {
198+
this.decoratedClass = decorated.getClass();
199+
this.originalClass = decorated.getOriginal().getClass();
200+
}
201+
202+
@Override
203+
public boolean equals(Object o) {
204+
if (o == null || getClass() != o.getClass()) return false;
205+
Definition definition = (Definition) o;
206+
return (decoratedClass == definition.decoratedClass)
207+
&& (originalClass == definition.originalClass);
208+
}
209+
210+
@Override
211+
public int hashCode() {
212+
return Objects.hash(decoratedClass, originalClass);
213+
}
214+
}
215+
216+
protected static class ProxyFactory<T> {
217+
private final Class<? extends T> clazz;
218+
private final Field targetLookup;
219+
220+
private ProxyFactory(Class<? extends T> clazz, Field targetLookup) {
221+
this.clazz = clazz;
222+
this.targetLookup = targetLookup;
223+
}
224+
225+
public T newInstance(Decorated<T> target) {
226+
T instance;
227+
try {
228+
instance = (T) clazz.newInstance();
229+
} catch (ReflectiveOperationException e) {
230+
throw new AssertionError("Unable to create new proxy", e);
231+
}
232+
233+
try {
234+
// ensure we can later find the target to call
235+
targetLookup.set(instance, target);
236+
} catch (IllegalAccessException ex) {
237+
throw new AssertionError("Failed to set proxy target", ex);
238+
}
239+
240+
return instance;
241+
}
242+
}
243+
244+
private final ConcurrentMap<Definition, ProxyFactory<?>> cache;
245+
186246
private final Class<T> targetWebDriverClass;
187247

188248
private Decorated<T> decorated;
@@ -194,6 +254,7 @@ public WebDriverDecorator() {
194254

195255
public WebDriverDecorator(Class<T> targetClass) {
196256
this.targetWebDriverClass = targetClass;
257+
this.cache = new ConcurrentHashMap<>();
197258
}
198259

199260
public final T decorate(T original) {
@@ -295,18 +356,38 @@ private Object decorateResult(Object toDecorate) {
295356
return toDecorate;
296357
}
297358

298-
protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<Z> clazz) {
299-
Set<Class<?>> decoratedInterfaces = extractInterfaces(decorated);
300-
Set<Class<?>> originalInterfaces = extractInterfaces(decorated.getOriginal());
301-
Map<Class<?>, InvocationHandler> derivedInterfaces =
302-
deriveAdditionalInterfaces(decorated.getOriginal());
359+
protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<? extends Z> clazz) {
360+
@SuppressWarnings("unchecked")
361+
ProxyFactory<Z> factory =
362+
(ProxyFactory<Z>)
363+
cache.computeIfAbsent(
364+
new Definition(decorated), (key) -> createProxyFactory(key, decorated, clazz));
365+
366+
return factory.newInstance(decorated);
367+
}
368+
369+
protected final <Z> ProxyFactory<? extends Z> createProxyFactory(
370+
Definition definition, final Decorated<Z> sample, Class<? extends Z> clazz) {
371+
Set<Class<?>> decoratedInterfaces = extractInterfaces(definition.decoratedClass);
372+
Set<Class<?>> originalInterfaces = extractInterfaces(definition.originalClass);
373+
// all samples with the same definition should have the same derivedInterfaces
374+
Map<Class<?>, Function<Z, InvocationHandler>> derivedInterfaces =
375+
deriveAdditionalInterfaces(sample.getOriginal());
376+
CompletableFuture<Field> proxyLookup = new CompletableFuture<>();
303377

304378
final InvocationHandler handler =
305379
(proxy, method, args) -> {
380+
Field target = proxyLookup.getNow(null);
381+
// Lookup the instance to call, to reuse the clazz and handler.
382+
@SuppressWarnings("unchecked")
383+
Decorated<Z> instance = (Decorated<Z>) target.get(proxy);
384+
if (instance == null) {
385+
throw new AssertionError("Failed to get instance to call");
386+
}
306387
try {
307388
if (method.getDeclaringClass().equals(Object.class)
308389
|| decoratedInterfaces.contains(method.getDeclaringClass())) {
309-
return method.invoke(decorated, args);
390+
return method.invoke(instance, args);
310391
}
311392
// Check if the class in which the method resides, implements any one of the
312393
// interfaces that we extracted from the decorated class.
@@ -317,9 +398,9 @@ protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<Z> clazz)
317398
eachInterface.isAssignableFrom(method.getDeclaringClass()));
318399

319400
if (isCompatible) {
320-
decorated.beforeCall(method, args);
321-
Object result = decorated.call(method, args);
322-
decorated.afterCall(method, result, args);
401+
instance.beforeCall(method, args);
402+
Object result = instance.call(method, args);
403+
instance.afterCall(method, result, args);
323404
return result;
324405
}
325406

@@ -333,12 +414,15 @@ protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<Z> clazz)
333414
eachInterface.isAssignableFrom(method.getDeclaringClass()));
334415

335416
if (isCompatible) {
336-
return derivedInterfaces.get(method.getDeclaringClass()).invoke(proxy, method, args);
417+
return derivedInterfaces
418+
.get(method.getDeclaringClass())
419+
.apply(instance.getOriginal())
420+
.invoke(proxy, method, args);
337421
}
338422

339-
return method.invoke(decorated.getOriginal(), args);
423+
return method.invoke(instance.getOriginal(), args);
340424
} catch (InvocationTargetException e) {
341-
return decorated.onError(method, e, args);
425+
return instance.onError(method, e, args);
342426
}
343427
};
344428

@@ -347,27 +431,45 @@ protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<Z> clazz)
347431
allInterfaces.addAll(originalInterfaces);
348432
allInterfaces.addAll(derivedInterfaces.keySet());
349433
Class<?>[] allInterfacesArray = allInterfaces.toArray(new Class<?>[0]);
434+
// ensure we do not hit an existing field
435+
int fieldName;
436+
437+
for (fieldName = 0; fieldName < 8192; fieldName++) {
438+
try {
439+
clazz.getDeclaredField("___target" + fieldName);
440+
} catch (NoSuchFieldException ex) {
441+
break;
442+
}
443+
}
444+
445+
if (fieldName == 8192) {
446+
throw new AssertionError("No free field name found");
447+
}
350448

351449
Class<? extends Z> proxy =
352450
new ByteBuddy()
353451
.subclass(clazz.isInterface() ? Object.class : clazz)
354452
.implement(allInterfacesArray)
355453
.method(ElementMatchers.any())
356454
.intercept(InvocationHandlerAdapter.of(handler))
455+
.defineField("___target" + fieldName, sample.getClass(), Visibility.PUBLIC)
357456
.make()
358457
.load(clazz.getClassLoader(), ClassLoadingStrategy.Default.WRAPPER)
359458
.getLoaded()
360459
.asSubclass(clazz);
361460

461+
Field lookup;
462+
362463
try {
363-
return proxy.newInstance();
364-
} catch (ReflectiveOperationException e) {
365-
throw new IllegalStateException("Unable to create new proxy", e);
464+
lookup = proxy.getDeclaredField("___target" + fieldName);
465+
} catch (NoSuchFieldException e) {
466+
proxyLookup.completeExceptionally(e);
467+
throw new AssertionError("Defined field is missing", e);
366468
}
367-
}
368469

369-
static Set<Class<?>> extractInterfaces(final Object object) {
370-
return extractInterfaces(object.getClass());
470+
proxyLookup.complete(lookup);
471+
472+
return new ProxyFactory<Z>(proxy, lookup);
371473
}
372474

373475
private static Set<Class<?>> extractInterfaces(final Class<?> clazz) {
@@ -393,43 +495,46 @@ private static void extractInterfaces(final Set<Class<?>> collector, final Class
393495
extractInterfaces(collector, clazz.getSuperclass());
394496
}
395497

396-
private Map<Class<?>, InvocationHandler> deriveAdditionalInterfaces(Object object) {
397-
Map<Class<?>, InvocationHandler> handlers = new HashMap<>();
498+
private <Z> Map<Class<?>, Function<Z, InvocationHandler>> deriveAdditionalInterfaces(Z sample) {
499+
Map<Class<?>, Function<Z, InvocationHandler>> handlers = new HashMap<>();
398500

399-
if (object instanceof WebDriver && !(object instanceof WrapsDriver)) {
501+
if (sample instanceof WebDriver && !(sample instanceof WrapsDriver)) {
400502
handlers.put(
401503
WrapsDriver.class,
402-
(proxy, method, args) -> {
403-
if ("getWrappedDriver".equals(method.getName())) {
404-
return object;
405-
}
406-
throw new UnsupportedOperationException(method.getName());
407-
});
504+
(instance) ->
505+
(proxy, method, args) -> {
506+
if ("getWrappedDriver".equals(method.getName())) {
507+
return instance;
508+
}
509+
throw new UnsupportedOperationException(method.getName());
510+
});
408511
}
409512

410-
if (object instanceof WebElement && !(object instanceof WrapsElement)) {
513+
if (sample instanceof WebElement && !(sample instanceof WrapsElement)) {
411514
handlers.put(
412515
WrapsElement.class,
413-
(proxy, method, args) -> {
414-
if ("getWrappedElement".equals(method.getName())) {
415-
return object;
416-
}
417-
throw new UnsupportedOperationException(method.getName());
418-
});
516+
(instance) ->
517+
(proxy, method, args) -> {
518+
if ("getWrappedElement".equals(method.getName())) {
519+
return instance;
520+
}
521+
throw new UnsupportedOperationException(method.getName());
522+
});
419523
}
420524

421525
try {
422-
Method toJson = object.getClass().getDeclaredMethod("toJson");
526+
Method toJson = sample.getClass().getDeclaredMethod("toJson");
423527
toJson.setAccessible(true);
424528

425529
handlers.put(
426530
JsonSerializer.class,
427-
((proxy, method, args) -> {
428-
if ("toJson".equals(method.getName())) {
429-
return toJson.invoke(object);
430-
}
431-
throw new UnsupportedOperationException(method.getName());
432-
}));
531+
(instance) ->
532+
((proxy, method, args) -> {
533+
if ("toJson".equals(method.getName())) {
534+
return toJson.invoke(instance);
535+
}
536+
throw new UnsupportedOperationException(method.getName());
537+
}));
433538
} catch (NoSuchMethodException e) {
434539
// Fine. Just fall through
435540
}

java/test/org/openqa/selenium/support/decorators/DecoratedWebDriverTest.java

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import java.util.HashSet;
3232
import java.util.List;
3333
import java.util.Set;
34+
import java.util.UUID;
3435
import java.util.function.Consumer;
3536
import java.util.function.Function;
3637
import org.junit.jupiter.api.Tag;
@@ -163,6 +164,45 @@ void findElement() {
163164
verifyDecoratingFunction($ -> $.findElement(By.id("test")), found, WebElement::click);
164165
}
165166

167+
@Test
168+
void doesNotCreateTooManyClasses() {
169+
final WebElement found = mock(WebElement.class);
170+
Function<WebDriver, WebElement> f = $ -> $.findElement(By.id("test"));
171+
Fixture fixture = new Fixture();
172+
when(f.apply(fixture.original)).thenReturn(found);
173+
174+
WebElement proxy1 = f.apply(fixture.decorated);
175+
WebElement proxy2 = f.apply(fixture.decorated);
176+
WebElement proxy3 = f.apply(fixture.decorated);
177+
178+
assertThat(proxy1.getClass()).isSameAs(proxy2.getClass());
179+
assertThat(proxy1.getClass()).isSameAs(proxy3.getClass());
180+
}
181+
182+
@Test
183+
void doesHitTheCorrectInstance() {
184+
String uuid0 = UUID.randomUUID().toString();
185+
String uuid1 = UUID.randomUUID().toString();
186+
String uuid2 = UUID.randomUUID().toString();
187+
final WebElement found0 = mock(WebElement.class);
188+
final WebElement found1 = mock(WebElement.class);
189+
final WebElement found2 = mock(WebElement.class);
190+
when(found0.getTagName()).thenReturn(uuid0);
191+
when(found1.getTagName()).thenReturn(uuid1);
192+
when(found2.getTagName()).thenReturn(uuid2);
193+
194+
Fixture fixture = new Fixture();
195+
Function<WebDriver, List<WebElement>> f = $ -> $.findElements(By.id("test"));
196+
197+
when(f.apply(fixture.original)).thenReturn(List.of(found0, found1, found2));
198+
199+
List<WebElement> proxies = f.apply(fixture.decorated);
200+
201+
assertThat(proxies.get(0).getTagName()).isEqualTo(uuid0);
202+
assertThat(proxies.get(1).getTagName()).isEqualTo(uuid1);
203+
assertThat(proxies.get(2).getTagName()).isEqualTo(uuid2);
204+
}
205+
166206
@Test
167207
void findElementNotFound() {
168208
Fixture fixture = new Fixture();

0 commit comments

Comments
 (0)