Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 131 additions & 44 deletions java/src/org/openqa/selenium/support/decorators/WebDriverDecorator.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,20 @@
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import net.bytebuddy.ByteBuddy;
import net.bytebuddy.description.modifier.Visibility;
import net.bytebuddy.dynamic.loading.ClassLoadingStrategy;
import net.bytebuddy.implementation.FieldAccessor;
import net.bytebuddy.implementation.InvocationHandlerAdapter;
import net.bytebuddy.matcher.ElementMatchers;
import org.openqa.selenium.Alert;
Expand Down Expand Up @@ -183,6 +189,65 @@
@Beta
public class WebDriverDecorator<T extends WebDriver> {

protected static class Definition {
private final Class<?> decoratedClass;
private final Class<?> originalClass;

public Definition(Decorated<?> decorated) {
this.decoratedClass = decorated.getClass();
this.originalClass = decorated.getOriginal().getClass();
}

@Override
public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
Definition definition = (Definition) o;
// intentionally an identity check, to ensure we get no false positive lookup due to an
// unknown implementation of decoratedClass.equals or originalClass.equals
return (decoratedClass == definition.decoratedClass)
&& (originalClass == definition.originalClass);
}

@Override
public int hashCode() {
return Arrays.hashCode(
new int[] {
System.identityHashCode(decoratedClass), System.identityHashCode(originalClass)
});
}
}

public interface HasTarget<Z> {
Decorated<Z> getTarget();

void setTarget(Decorated<Z> target);
}

protected static class ProxyFactory<T> {
private final Class<? extends T> clazz;

private ProxyFactory(Class<? extends T> clazz) {
this.clazz = clazz;
}

public T newInstance(Decorated<T> target) {
T instance;
try {
instance = (T) clazz.newInstance();
} catch (ReflectiveOperationException e) {
throw new AssertionError("Unable to create new proxy", e);
}

// ensure we can later find the target to call
//noinspection unchecked
((HasTarget<T>) instance).setTarget(target);

return instance;
}
}

private final ConcurrentMap<Definition, ProxyFactory<?>> cache;

private final Class<T> targetWebDriverClass;

private Decorated<T> decorated;
Expand All @@ -194,6 +259,7 @@ public WebDriverDecorator() {

public WebDriverDecorator(Class<T> targetClass) {
this.targetWebDriverClass = targetClass;
this.cache = new ConcurrentHashMap<>();
}

public final T decorate(T original) {
Expand Down Expand Up @@ -295,18 +361,36 @@ private Object decorateResult(Object toDecorate) {
return toDecorate;
}

protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<Z> clazz) {
Set<Class<?>> decoratedInterfaces = extractInterfaces(decorated);
Set<Class<?>> originalInterfaces = extractInterfaces(decorated.getOriginal());
Map<Class<?>, InvocationHandler> derivedInterfaces =
deriveAdditionalInterfaces(decorated.getOriginal());
protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<? extends Z> clazz) {
@SuppressWarnings("unchecked")
ProxyFactory<Z> factory =
(ProxyFactory<Z>)
cache.computeIfAbsent(
new Definition(decorated), (key) -> createProxyFactory(key, decorated, clazz));

return factory.newInstance(decorated);
}

protected final <Z> ProxyFactory<? extends Z> createProxyFactory(
Definition definition, final Decorated<Z> sample, Class<? extends Z> clazz) {
Set<Class<?>> decoratedInterfaces = extractInterfaces(definition.decoratedClass);
Set<Class<?>> originalInterfaces = extractInterfaces(definition.originalClass);
// all samples with the same definition should have the same derivedInterfaces
Map<Class<?>, Function<Z, InvocationHandler>> derivedInterfaces =
deriveAdditionalInterfaces(sample.getOriginal());

final InvocationHandler handler =
(proxy, method, args) -> {
// Lookup the instance to call, to reuse the clazz and handler.
@SuppressWarnings("unchecked")
Decorated<Z> instance = ((HasTarget<Z>) proxy).getTarget();
if (instance == null) {
throw new AssertionError("Failed to get instance to call");
}
try {
if (method.getDeclaringClass().equals(Object.class)
|| decoratedInterfaces.contains(method.getDeclaringClass())) {
return method.invoke(decorated, args);
return method.invoke(instance, args);
}
// Check if the class in which the method resides, implements any one of the
// interfaces that we extracted from the decorated class.
Expand All @@ -317,9 +401,9 @@ protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<Z> clazz)
eachInterface.isAssignableFrom(method.getDeclaringClass()));

if (isCompatible) {
decorated.beforeCall(method, args);
Object result = decorated.call(method, args);
decorated.afterCall(method, result, args);
instance.beforeCall(method, args);
Object result = instance.call(method, args);
instance.afterCall(method, result, args);
return result;
}

Expand All @@ -333,19 +417,24 @@ protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<Z> clazz)
eachInterface.isAssignableFrom(method.getDeclaringClass()));

if (isCompatible) {
return derivedInterfaces.get(method.getDeclaringClass()).invoke(proxy, method, args);
return derivedInterfaces
.get(method.getDeclaringClass())
.apply(instance.getOriginal())
.invoke(proxy, method, args);
}

return method.invoke(decorated.getOriginal(), args);
return method.invoke(instance.getOriginal(), args);
} catch (InvocationTargetException e) {
return decorated.onError(method, e, args);
return instance.onError(method, e, args);
}
};

Set<Class<?>> allInterfaces = new HashSet<>();
allInterfaces.addAll(decoratedInterfaces);
allInterfaces.addAll(originalInterfaces);
allInterfaces.addAll(derivedInterfaces.keySet());
// ensure a decorated driver can get decorated again
allInterfaces.remove(HasTarget.class);
Class<?>[] allInterfacesArray = allInterfaces.toArray(new Class<?>[0]);

Class<? extends Z> proxy =
Expand All @@ -354,20 +443,15 @@ protected final <Z> Z createProxy(final Decorated<Z> decorated, Class<Z> clazz)
.implement(allInterfacesArray)
.method(ElementMatchers.any())
.intercept(InvocationHandlerAdapter.of(handler))
.defineField("target", Decorated.class, Visibility.PRIVATE)
.implement(HasTarget.class)
.intercept(FieldAccessor.ofField("target"))
.make()
.load(clazz.getClassLoader(), ClassLoadingStrategy.Default.WRAPPER)
.getLoaded()
.asSubclass(clazz);

try {
return proxy.newInstance();
} catch (ReflectiveOperationException e) {
throw new IllegalStateException("Unable to create new proxy", e);
}
}

static Set<Class<?>> extractInterfaces(final Object object) {
return extractInterfaces(object.getClass());
return new ProxyFactory<Z>(proxy);
}

private static Set<Class<?>> extractInterfaces(final Class<?> clazz) {
Expand All @@ -393,43 +477,46 @@ private static void extractInterfaces(final Set<Class<?>> collector, final Class
extractInterfaces(collector, clazz.getSuperclass());
}

private Map<Class<?>, InvocationHandler> deriveAdditionalInterfaces(Object object) {
Map<Class<?>, InvocationHandler> handlers = new HashMap<>();
private <Z> Map<Class<?>, Function<Z, InvocationHandler>> deriveAdditionalInterfaces(Z sample) {
Map<Class<?>, Function<Z, InvocationHandler>> handlers = new HashMap<>();

if (object instanceof WebDriver && !(object instanceof WrapsDriver)) {
if (sample instanceof WebDriver && !(sample instanceof WrapsDriver)) {
handlers.put(
WrapsDriver.class,
(proxy, method, args) -> {
if ("getWrappedDriver".equals(method.getName())) {
return object;
}
throw new UnsupportedOperationException(method.getName());
});
(instance) ->
(proxy, method, args) -> {
if ("getWrappedDriver".equals(method.getName())) {
return instance;
}
throw new UnsupportedOperationException(method.getName());
});
}

if (object instanceof WebElement && !(object instanceof WrapsElement)) {
if (sample instanceof WebElement && !(sample instanceof WrapsElement)) {
handlers.put(
WrapsElement.class,
(proxy, method, args) -> {
if ("getWrappedElement".equals(method.getName())) {
return object;
}
throw new UnsupportedOperationException(method.getName());
});
(instance) ->
(proxy, method, args) -> {
if ("getWrappedElement".equals(method.getName())) {
return instance;
}
throw new UnsupportedOperationException(method.getName());
});
}

try {
Method toJson = object.getClass().getDeclaredMethod("toJson");
Method toJson = sample.getClass().getDeclaredMethod("toJson");
toJson.setAccessible(true);

handlers.put(
JsonSerializer.class,
((proxy, method, args) -> {
if ("toJson".equals(method.getName())) {
return toJson.invoke(object);
}
throw new UnsupportedOperationException(method.getName());
}));
(instance) ->
((proxy, method, args) -> {
if ("toJson".equals(method.getName())) {
return toJson.invoke(instance);
}
throw new UnsupportedOperationException(method.getName());
}));
} catch (NoSuchMethodException e) {
// Fine. Just fall through
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.UUID;
import java.util.function.Consumer;
import java.util.function.Function;
import org.junit.jupiter.api.Tag;
Expand Down Expand Up @@ -163,6 +164,55 @@ void findElement() {
verifyDecoratingFunction($ -> $.findElement(By.id("test")), found, WebElement::click);
}

@Test
void doesNotCreateTooManyClasses() {
final WebElement found0 = mock(WebElement.class);
final WebElement found1 = mock(WebElement.class);
final WebElement found2 = mock(WebElement.class);
Function<WebDriver, WebElement> f = $ -> $.findElement(By.id("test"));
Function<WebDriver, List<WebElement>> f2 = $ -> $.findElements(By.id("test"));
Fixture fixture = new Fixture();
when(f.apply(fixture.original)).thenReturn(found0);
when(f2.apply(fixture.original)).thenReturn(List.of(found0, found1, found2));

WebElement proxy0 = f.apply(fixture.decorated);
WebElement proxy1 = f.apply(fixture.decorated);
WebElement proxy2 = f.apply(fixture.decorated);

assertThat(proxy0.getClass()).isSameAs(proxy1.getClass());
assertThat(proxy1.getClass()).isSameAs(proxy2.getClass());

List<WebElement> proxies = f2.apply(fixture.decorated);

assertThat(proxy0.getClass()).isSameAs(proxies.get(0).getClass());
assertThat(proxy0.getClass()).isSameAs(proxies.get(1).getClass());
assertThat(proxy0.getClass()).isSameAs(proxies.get(2).getClass());
}

@Test
void doesHitTheCorrectInstance() {
String uuid0 = UUID.randomUUID().toString();
String uuid1 = UUID.randomUUID().toString();
String uuid2 = UUID.randomUUID().toString();
final WebElement found0 = mock(WebElement.class);
final WebElement found1 = mock(WebElement.class);
final WebElement found2 = mock(WebElement.class);
when(found0.getTagName()).thenReturn(uuid0);
when(found1.getTagName()).thenReturn(uuid1);
when(found2.getTagName()).thenReturn(uuid2);

Fixture fixture = new Fixture();
Function<WebDriver, List<WebElement>> f = $ -> $.findElements(By.id("test"));

when(f.apply(fixture.original)).thenReturn(List.of(found0, found1, found2));

List<WebElement> proxies = f.apply(fixture.decorated);

assertThat(proxies.get(0).getTagName()).isEqualTo(uuid0);
assertThat(proxies.get(1).getTagName()).isEqualTo(uuid1);
assertThat(proxies.get(2).getTagName()).isEqualTo(uuid2);
}

@Test
void findElementNotFound() {
Fixture fixture = new Fixture();
Expand Down
Loading