@@ -38,22 +38,34 @@ class MemoryCache(
38
38
get() = lruCache.size()
39
39
40
40
override fun loadRecord (key : String , cacheHeaders : CacheHeaders ): Record ? = lock.lock {
41
- val cacheEntry = lruCache[key]?.also { cacheEntry ->
42
- if (cacheEntry.isExpired || cacheHeaders.hasHeader(ApolloCacheHeaders .EVICT_AFTER_READ )) {
43
- lruCache.remove(key)
44
- }
45
- }
46
-
47
- cacheEntry?.takeUnless { it.isExpired }?.record ? : nextCache?.loadRecord(key, cacheHeaders)?.also { nextCachedRecord ->
41
+ val record = internalLoadRecord(key, cacheHeaders)
42
+ record ? : nextCache?.loadRecord(key, cacheHeaders)?.also { nextCachedRecord ->
48
43
lruCache[key] = CacheEntry (
49
44
record = nextCachedRecord,
50
45
expireAfterMillis = expireAfterMillis
51
46
)
52
47
}
53
48
}
54
49
55
- override fun loadRecords (keys : Collection <String >, cacheHeaders : CacheHeaders ): Collection <Record > {
56
- return keys.mapNotNull { key -> loadRecord(key, cacheHeaders) }
50
+ override fun loadRecords (keys : Collection <String >, cacheHeaders : CacheHeaders ): Collection <Record > = lock.lock {
51
+ val recordsByKey: Map <String , Record ?> = keys.associateWith { key -> internalLoadRecord(key, cacheHeaders) }
52
+ val missingKeys = recordsByKey.filterValues { it == null }.keys
53
+ val nextCachedRecords = nextCache?.loadRecords(missingKeys, cacheHeaders).orEmpty()
54
+ for (record in nextCachedRecords) {
55
+ lruCache[record.key] = CacheEntry (
56
+ record = record,
57
+ expireAfterMillis = expireAfterMillis
58
+ )
59
+ }
60
+ recordsByKey.values.filterNotNull() + nextCachedRecords
61
+ }
62
+
63
+ private fun internalLoadRecord (key : String , cacheHeaders : CacheHeaders ): Record ? {
64
+ return lruCache[key]?.also { cacheEntry ->
65
+ if (cacheEntry.isExpired || cacheHeaders.hasHeader(ApolloCacheHeaders .EVICT_AFTER_READ )) {
66
+ lruCache.remove(key)
67
+ }
68
+ }?.takeUnless { it.isExpired }?.record
57
69
}
58
70
59
71
override fun clearAll () {
@@ -79,7 +91,7 @@ class MemoryCache(
79
91
var total = 0
80
92
val keys = HashSet (lruCache.keys()) // local copy to avoid concurrent modification
81
93
keys.forEach {
82
- if (regex.matches(it)){
94
+ if (regex.matches(it)) {
83
95
lruCache.remove(it)
84
96
total++
85
97
}
@@ -137,7 +149,7 @@ class MemoryCache(
137
149
138
150
private class CacheEntry (
139
151
val record : Record ,
140
- val expireAfterMillis : Long
152
+ val expireAfterMillis : Long ,
141
153
) {
142
154
val cachedAtMillis: Long = currentTimeMillis()
143
155
0 commit comments