|
10 | 10 | import java.lang.annotation.RetentionPolicy; |
11 | 11 | import java.lang.annotation.Target; |
12 | 12 | import java.lang.invoke.*; |
| 13 | +import java.lang.ref.WeakReference; |
| 14 | +import java.util.HashMap; |
| 15 | +import java.util.Map; |
13 | 16 | import java.util.Objects; |
| 17 | +import java.util.WeakHashMap; |
14 | 18 | import java.util.concurrent.ConcurrentHashMap; |
15 | 19 | import java.util.concurrent.ConcurrentMap; |
| 20 | +import java.util.concurrent.locks.Lock; |
| 21 | +import java.util.concurrent.locks.ReentrantLock; |
16 | 22 | import java.util.function.BinaryOperator; |
17 | 23 |
|
18 | 24 | public final class LambdaFactory { |
19 | 25 | private static final class LambdaClassLoader extends ClassLoader { |
| 26 | + private static final Map<ClassLoader, WeakReference<LambdaClassLoader>> KNOWN_LOADERS = new WeakHashMap<>(); |
| 27 | + |
20 | 28 | static { |
21 | 29 | ClassLoader.registerAsParallelCapable(); |
22 | 30 | } |
23 | 31 |
|
24 | | - private final String name; |
25 | | - private final byte[] data; |
| 32 | + private final Lock lock; |
| 33 | + private final Map<String, byte[]> knownLambdas; |
26 | 34 |
|
27 | | - private LambdaClassLoader(final String name, final byte[] data, final ClassLoader parent) { |
| 35 | + private LambdaClassLoader(final ClassLoader parent) { |
28 | 36 | super(parent); |
29 | | - this.name = name; |
30 | | - this.data = data; |
| 37 | + this.lock = new ReentrantLock(); |
| 38 | + this.knownLambdas = new HashMap<>(); |
| 39 | + } |
| 40 | + |
| 41 | + private static LambdaClassLoader findLoader(final MethodHandles.Lookup lookup) { |
| 42 | + final ClassLoader lookupClassLoader = lookup.lookupClass().getClassLoader(); |
| 43 | + final WeakReference<LambdaClassLoader> loaderRef = KNOWN_LOADERS.get(lookupClassLoader); |
| 44 | + if (loaderRef == null || loaderRef.get() == null) { |
| 45 | + // LambdaClassLoader reference was lost (or it never existed), so recreate |
| 46 | + final LambdaClassLoader loader = new LambdaClassLoader(lookupClassLoader); |
| 47 | + KNOWN_LOADERS.put(loader, new WeakReference<>(loader)); |
| 48 | + return loader; |
| 49 | + } |
| 50 | + return loaderRef.get(); |
31 | 51 | } |
32 | 52 |
|
33 | | - static ClassLoader spinLoader(final String name, final byte[] data, final MethodHandles.Lookup lookup) { |
34 | | - return new LambdaClassLoader(name, data, lookup.lookupClass().getClassLoader()); |
| 53 | + LambdaClassLoader registerLambda(final String name, final byte[] data) { |
| 54 | + this.lock.lock(); |
| 55 | + try { |
| 56 | + if (this.knownLambdas.containsKey(name)) { |
| 57 | + throw new IllegalStateException("Lambda with name '" + name + "' already exists"); |
| 58 | + } |
| 59 | + this.knownLambdas.put(name, data); |
| 60 | + } finally { |
| 61 | + this.lock.unlock(); |
| 62 | + } |
| 63 | + return this; |
35 | 64 | } |
36 | 65 |
|
37 | 66 | @Override |
38 | 67 | protected Class<?> findClass(final String name) throws ClassNotFoundException { |
39 | | - if (this.name.equals(name)) { |
40 | | - return this.defineClass(name, this.data, 0, this.data.length); |
| 68 | + // Step 1: try loading the class directly without locking; there should never be an instance where a thread |
| 69 | + // is trying to load a class that is being registered on another thread |
| 70 | + Class<?> lambdaClass = this.tryLoadLambdaClass(name); |
| 71 | + if (lambdaClass != null) { |
| 72 | + return lambdaClass; |
41 | 73 | } |
| 74 | + |
| 75 | + // Step 2: if the previous step failed, let's retry with locking just in case the above situation happened |
| 76 | + this.lock.lock(); |
| 77 | + try { |
| 78 | + lambdaClass = this.tryLoadLambdaClass(name); |
| 79 | + if (lambdaClass != null) { |
| 80 | + return lambdaClass; |
| 81 | + } |
| 82 | + } finally { |
| 83 | + this.lock.unlock(); |
| 84 | + } |
| 85 | + |
| 86 | + // Step 3: Defer to default behavior |
42 | 87 | return super.findClass(name); |
43 | 88 | } |
44 | 89 |
|
45 | | - |
| 90 | + private Class<?> tryLoadLambdaClass(final String name) { |
| 91 | + final byte[] lambdaBytes = this.knownLambdas.get(name); |
| 92 | + if (lambdaBytes != null) { |
| 93 | + return this.defineClass(name, lambdaBytes, 0, lambdaBytes.length); |
| 94 | + } |
| 95 | + return null; |
| 96 | + } |
46 | 97 | } |
47 | 98 |
|
48 | 99 | private static final class LambdaCounters { |
@@ -497,7 +548,8 @@ private static Class<?> generateInterfaceImplementation( |
497 | 548 | final byte[] classData = generateInterfaceClassData(targetMethodName, callSiteSignature, lambdaMethod, interfaceSignature, flags, bridgeInterfaceSignature, className); |
498 | 549 |
|
499 | 550 | // Java 8 forces us to use a custom classloader for this; ideally we'd be leveraging hidden classes |
500 | | - final ClassLoader lambdaLoader = LambdaClassLoader.spinLoader(binaryName, classData, callerLookup); |
| 551 | + final LambdaClassLoader lookupLoader = LambdaClassLoader.findLoader(callerLookup); |
| 552 | + final LambdaClassLoader lambdaLoader = lookupLoader.registerLambda(binaryName, classData); |
501 | 553 | try { |
502 | 554 | return Class.forName(binaryName, false, lambdaLoader); |
503 | 555 | } catch (final ClassNotFoundException e) { |
|
0 commit comments