Skip to content

Commit 906a5b3

Browse files
committed
[java] reuse the classes created by the WebDriverDecorator #14789
1 parent cfaa8c4 commit 906a5b3

File tree

2 files changed

+166
-44
lines changed

2 files changed

+166
-44
lines changed

java/src/org/openqa/selenium/support/decorators/WebDriverDecorator.java

Lines changed: 126 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,16 @@
2424
import java.util.HashSet;
2525
import java.util.List;
2626
import java.util.Map;
27+
import java.util.Objects;
2728
import java.util.Set;
29+
import java.util.concurrent.ConcurrentHashMap;
30+
import java.util.concurrent.ConcurrentMap;
31+
import java.util.function.Function;
2832
import java.util.stream.Collectors;
2933
import net.bytebuddy.ByteBuddy;
34+
import net.bytebuddy.description.modifier.Visibility;
3035
import net.bytebuddy.dynamic.loading.ClassLoadingStrategy;
36+
import net.bytebuddy.implementation.FieldAccessor;
3137
import net.bytebuddy.implementation.InvocationHandlerAdapter;
3238
import net.bytebuddy.matcher.ElementMatchers;
3339
import org.openqa.selenium.Alert;
@@ -183,6 +189,60 @@
183189
@Beta
184190
public class WebDriverDecorator<T extends WebDriver> {
185191

192+
protected static class Definition {
193+
private final Class<?> decoratedClass;
194+
private final Class<?> originalClass;
195+
196+
public Definition(Decorated<?> decorated) {
197+
this.decoratedClass = decorated.getClass();
198+
this.originalClass = decorated.getOriginal().getClass();
199+
}
200+
201+
@Override
202+
public boolean equals(Object o) {
203+
if (o == null || getClass() != o.getClass()) return false;
204+
Definition definition = (Definition) o;
205+
return (decoratedClass == definition.decoratedClass)
206+
&& (originalClass == definition.originalClass);
207+
}
208+
209+
@Override
210+
public int hashCode() {
211+
return Objects.hash(decoratedClass, originalClass);
212+
}
213+
}
214+
215+
public interface HasTarget<Z> {
216+
Decorated<Z> getTarget();
217+
218+
void setTarget(Decorated<Z> target);
219+
}
220+
221+
protected static class ProxyFactory<T> {
222+
private final Class<? extends T> clazz;
223+
224+
private ProxyFactory(Class<? extends T> clazz) {
225+
this.clazz = clazz;
226+
}
227+
228+
public T newInstance(Decorated<T> target) {
229+
T instance;
230+
try {
231+
instance = (T) clazz.newInstance();
232+
} catch (ReflectiveOperationException e) {
233+
throw new AssertionError("Unable to create new proxy", e);
234+
}
235+
236+
// ensure we can later find the target to call
237+
//noinspection unchecked
238+
((HasTarget<T>) instance).setTarget(target);
239+
240+
return instance;
241+
}
242+
}
243+
244+
private final ConcurrentMap<Definition, ProxyFactory<?>> cache;
245+
186246
private final Class<T> targetWebDriverClass;
187247

188248
private Decorated<T> decorated;
@@ -194,6 +254,7 @@ public WebDriverDecorator() {
194254

195255
public WebDriverDecorator(Class<T> targetClass) {
196256
this.targetWebDriverClass = targetClass;
257+
this.cache = new ConcurrentHashMap<>();
197258
}
198259

199260
public final T decorate(T original) {
@@ -295,18 +356,36 @@ private Object decorateResult(Object toDecorate) {
295356
return toDecorate;
296357
}
297358

298-
protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<Z> clazz) {
299-
Set<Class<?>> decoratedInterfaces = extractInterfaces(decorated);
300-
Set<Class<?>> originalInterfaces = extractInterfaces(decorated.getOriginal());
301-
Map<Class<?>, InvocationHandler> derivedInterfaces =
302-
deriveAdditionalInterfaces(decorated.getOriginal());
359+
protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<? extends Z> clazz) {
360+
@SuppressWarnings("unchecked")
361+
ProxyFactory<Z> factory =
362+
(ProxyFactory<Z>)
363+
cache.computeIfAbsent(
364+
new Definition(decorated), (key) -> createProxyFactory(key, decorated, clazz));
365+
366+
return factory.newInstance(decorated);
367+
}
368+
369+
protected final <Z> ProxyFactory<? extends Z> createProxyFactory(
370+
Definition definition, final Decorated<Z> sample, Class<? extends Z> clazz) {
371+
Set<Class<?>> decoratedInterfaces = extractInterfaces(definition.decoratedClass);
372+
Set<Class<?>> originalInterfaces = extractInterfaces(definition.originalClass);
373+
// all samples with the same definition should have the same derivedInterfaces
374+
Map<Class<?>, Function<Z, InvocationHandler>> derivedInterfaces =
375+
deriveAdditionalInterfaces(sample.getOriginal());
303376

304377
final InvocationHandler handler =
305378
(proxy, method, args) -> {
379+
// Lookup the instance to call, to reuse the clazz and handler.
380+
@SuppressWarnings("unchecked")
381+
Decorated<Z> instance = ((HasTarget<Z>) proxy).getTarget();
382+
if (instance == null) {
383+
throw new AssertionError("Failed to get instance to call");
384+
}
306385
try {
307386
if (method.getDeclaringClass().equals(Object.class)
308387
|| decoratedInterfaces.contains(method.getDeclaringClass())) {
309-
return method.invoke(decorated, args);
388+
return method.invoke(instance, args);
310389
}
311390
// Check if the class in which the method resides, implements any one of the
312391
// interfaces that we extracted from the decorated class.
@@ -317,9 +396,9 @@ protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<Z> clazz)
317396
eachInterface.isAssignableFrom(method.getDeclaringClass()));
318397

319398
if (isCompatible) {
320-
decorated.beforeCall(method, args);
321-
Object result = decorated.call(method, args);
322-
decorated.afterCall(method, result, args);
399+
instance.beforeCall(method, args);
400+
Object result = instance.call(method, args);
401+
instance.afterCall(method, result, args);
323402
return result;
324403
}
325404

@@ -333,19 +412,24 @@ protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<Z> clazz)
333412
eachInterface.isAssignableFrom(method.getDeclaringClass()));
334413

335414
if (isCompatible) {
336-
return derivedInterfaces.get(method.getDeclaringClass()).invoke(proxy, method, args);
415+
return derivedInterfaces
416+
.get(method.getDeclaringClass())
417+
.apply(instance.getOriginal())
418+
.invoke(proxy, method, args);
337419
}
338420

339-
return method.invoke(decorated.getOriginal(), args);
421+
return method.invoke(instance.getOriginal(), args);
340422
} catch (InvocationTargetException e) {
341-
return decorated.onError(method, e, args);
423+
return instance.onError(method, e, args);
342424
}
343425
};
344426

345427
Set<Class<?>> allInterfaces = new HashSet<>();
346428
allInterfaces.addAll(decoratedInterfaces);
347429
allInterfaces.addAll(originalInterfaces);
348430
allInterfaces.addAll(derivedInterfaces.keySet());
431+
// ensure a decorated driver can get decorated again
432+
allInterfaces.remove(HasTarget.class);
349433
Class<?>[] allInterfacesArray = allInterfaces.toArray(new Class<?>[0]);
350434

351435
Class<? extends Z> proxy =
@@ -354,20 +438,15 @@ protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<Z> clazz)
354438
.implement(allInterfacesArray)
355439
.method(ElementMatchers.any())
356440
.intercept(InvocationHandlerAdapter.of(handler))
441+
.defineField("target", Decorated.class, Visibility.PRIVATE)
442+
.implement(HasTarget.class)
443+
.intercept(FieldAccessor.ofField("target"))
357444
.make()
358445
.load(clazz.getClassLoader(), ClassLoadingStrategy.Default.WRAPPER)
359446
.getLoaded()
360447
.asSubclass(clazz);
361448

362-
try {
363-
return proxy.newInstance();
364-
} catch (ReflectiveOperationException e) {
365-
throw new IllegalStateException("Unable to create new proxy", e);
366-
}
367-
}
368-
369-
static Set<Class<?>> extractInterfaces(final Object object) {
370-
return extractInterfaces(object.getClass());
449+
return new ProxyFactory<Z>(proxy);
371450
}
372451

373452
private static Set<Class<?>> extractInterfaces(final Class<?> clazz) {
@@ -393,43 +472,46 @@ private static void extractInterfaces(final Set<Class<?>> collector, final Class
393472
extractInterfaces(collector, clazz.getSuperclass());
394473
}
395474

396-
private Map<Class<?>, InvocationHandler> deriveAdditionalInterfaces(Object object) {
397-
Map<Class<?>, InvocationHandler> handlers = new HashMap<>();
475+
private <Z> Map<Class<?>, Function<Z, InvocationHandler>> deriveAdditionalInterfaces(Z sample) {
476+
Map<Class<?>, Function<Z, InvocationHandler>> handlers = new HashMap<>();
398477

399-
if (object instanceof WebDriver && !(object instanceof WrapsDriver)) {
478+
if (sample instanceof WebDriver && !(sample instanceof WrapsDriver)) {
400479
handlers.put(
401480
WrapsDriver.class,
402-
(proxy, method, args) -> {
403-
if ("getWrappedDriver".equals(method.getName())) {
404-
return object;
405-
}
406-
throw new UnsupportedOperationException(method.getName());
407-
});
481+
(instance) ->
482+
(proxy, method, args) -> {
483+
if ("getWrappedDriver".equals(method.getName())) {
484+
return instance;
485+
}
486+
throw new UnsupportedOperationException(method.getName());
487+
});
408488
}
409489

410-
if (object instanceof WebElement && !(object instanceof WrapsElement)) {
490+
if (sample instanceof WebElement && !(sample instanceof WrapsElement)) {
411491
handlers.put(
412492
WrapsElement.class,
413-
(proxy, method, args) -> {
414-
if ("getWrappedElement".equals(method.getName())) {
415-
return object;
416-
}
417-
throw new UnsupportedOperationException(method.getName());
418-
});
493+
(instance) ->
494+
(proxy, method, args) -> {
495+
if ("getWrappedElement".equals(method.getName())) {
496+
return instance;
497+
}
498+
throw new UnsupportedOperationException(method.getName());
499+
});
419500
}
420501

421502
try {
422-
Method toJson = object.getClass().getDeclaredMethod("toJson");
503+
Method toJson = sample.getClass().getDeclaredMethod("toJson");
423504
toJson.setAccessible(true);
424505

425506
handlers.put(
426507
JsonSerializer.class,
427-
((proxy, method, args) -> {
428-
if ("toJson".equals(method.getName())) {
429-
return toJson.invoke(object);
430-
}
431-
throw new UnsupportedOperationException(method.getName());
432-
}));
508+
(instance) ->
509+
((proxy, method, args) -> {
510+
if ("toJson".equals(method.getName())) {
511+
return toJson.invoke(instance);
512+
}
513+
throw new UnsupportedOperationException(method.getName());
514+
}));
433515
} catch (NoSuchMethodException e) {
434516
// Fine. Just fall through
435517
}

java/test/org/openqa/selenium/support/decorators/DecoratedWebDriverTest.java

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import java.util.HashSet;
3232
import java.util.List;
3333
import java.util.Set;
34+
import java.util.UUID;
3435
import java.util.function.Consumer;
3536
import java.util.function.Function;
3637
import org.junit.jupiter.api.Tag;
@@ -163,6 +164,45 @@ void findElement() {
163164
verifyDecoratingFunction($ -> $.findElement(By.id("test")), found, WebElement::click);
164165
}
165166

167+
@Test
168+
void doesNotCreateTooManyClasses() {
169+
final WebElement found = mock(WebElement.class);
170+
Function<WebDriver, WebElement> f = $ -> $.findElement(By.id("test"));
171+
Fixture fixture = new Fixture();
172+
when(f.apply(fixture.original)).thenReturn(found);
173+
174+
WebElement proxy1 = f.apply(fixture.decorated);
175+
WebElement proxy2 = f.apply(fixture.decorated);
176+
WebElement proxy3 = f.apply(fixture.decorated);
177+
178+
assertThat(proxy1.getClass()).isSameAs(proxy2.getClass());
179+
assertThat(proxy1.getClass()).isSameAs(proxy3.getClass());
180+
}
181+
182+
@Test
183+
void doesHitTheCorrectInstance() {
184+
String uuid0 = UUID.randomUUID().toString();
185+
String uuid1 = UUID.randomUUID().toString();
186+
String uuid2 = UUID.randomUUID().toString();
187+
final WebElement found0 = mock(WebElement.class);
188+
final WebElement found1 = mock(WebElement.class);
189+
final WebElement found2 = mock(WebElement.class);
190+
when(found0.getTagName()).thenReturn(uuid0);
191+
when(found1.getTagName()).thenReturn(uuid1);
192+
when(found2.getTagName()).thenReturn(uuid2);
193+
194+
Fixture fixture = new Fixture();
195+
Function<WebDriver, List<WebElement>> f = $ -> $.findElements(By.id("test"));
196+
197+
when(f.apply(fixture.original)).thenReturn(List.of(found0, found1, found2));
198+
199+
List<WebElement> proxies = f.apply(fixture.decorated);
200+
201+
assertThat(proxies.get(0).getTagName()).isEqualTo(uuid0);
202+
assertThat(proxies.get(1).getTagName()).isEqualTo(uuid1);
203+
assertThat(proxies.get(2).getTagName()).isEqualTo(uuid2);
204+
}
205+
166206
@Test
167207
void findElementNotFound() {
168208
Fixture fixture = new Fixture();

0 commit comments

Comments
 (0)