Skip to content

Commit 544a784

Browse files
Fix race-condition and make key mapping testable separately.
1 parent 15fd0f1 commit 544a784

File tree

3 files changed

+181
-60
lines changed

3 files changed

+181
-60
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.blobcache.shared;
9+
10+
import java.util.concurrent.ConcurrentHashMap;
11+
import java.util.function.BiConsumer;
12+
import java.util.function.Function;
13+
14+
/**
 * A 2 layer key mapping for the shared cache.
 * <p>
 * Values live in one inner {@link ConcurrentHashMap} per outer key. Entries are only ever added to an inner map, and an inner map is
 * only ever detached from the outer map, from within a {@code compute}/{@code computeIfPresent} call on the outer map for that key.
 * This prevents a value from being added to an inner map that has concurrently been detached because it was empty.
 *
 * @param <Key1> The outer layer key type
 * @param <Key2> The inner key type
 * @param <Value> The value type
 */
class KeyMapping<Key1, Key2, Value> {
    private final ConcurrentHashMap<Key1, ConcurrentHashMap<Key2, Value>> mapping = new ConcurrentHashMap<>();

    /**
     * Look up the value stored under the given key pair.
     * @param key1 The key1 part
     * @param key2 The key2 part
     * @return the current value, or {@code null} if there is no mapping for the key pair.
     */
    public Value get(Key1 key1, Key2 key2) {
        final ConcurrentHashMap<Key2, Value> inner = mapping.get(key1);
        return inner == null ? null : inner.get(key2);
    }

    /**
     * Compute a key if absent. Notice that unlike CHM#computeIfAbsent, locking will be done also when present
     * @param key1 The key1 part
     * @param key2 The key2 part
     * @param function the function to get from key2 to the value
     * @return the value that was present or computed while holding the lock for key1, or {@code null} if the function returned
     *         {@code null} (in which case nothing is stored).
     */
    public Value computeIfAbsent(Key1 key1, Key2 key2, Function<Key2, Value> function) {
        // Single-slot holder so the value seen under the key1 lock can be handed out of the lambda. Re-reading the inner map after
        // the lock is released could observe a concurrent remove and wrongly return null.
        final Object[] holder = new Object[1];
        mapping.compute(key1, (k, current) -> {
            ConcurrentHashMap<Key2, Value> map = current == null ? new ConcurrentHashMap<>() : current;
            holder[0] = map.computeIfAbsent(key2, function);
            // If the function produced no value, do not leave an empty inner map registered under key1.
            return map.isEmpty() ? null : map;
        });
        @SuppressWarnings("unchecked")
        final Value result = (Value) holder[0];
        return result;
    }

    /**
     * Remove the mapping for the key pair, but only if it currently maps to the given value.
     * @param key1 The key1 part
     * @param key2 The key2 part
     * @param value the value expected to be mapped
     * @return true if the mapping was removed, false if the key pair was not mapped to the given value.
     */
    public boolean remove(Key1 key1, Key2 key2, Value value) {
        ConcurrentHashMap<Key2, Value> inner = mapping.get(key1);
        if (inner != null) {
            boolean removed = inner.remove(key2, value);
            if (removed && inner.isEmpty()) {
                // Re-check emptiness under the key1 lock: a concurrent computeIfAbsent may have added an entry in the meantime.
                mapping.computeIfPresent(key1, (k, v) -> v.isEmpty() ? null : v);
            }
            return removed;
        }
        return false;
    }

    /**
     * @return a view of the outer keys that currently have an inner map registered.
     */
    Iterable<Key1> key1s() {
        return mapping.keySet();
    }

    /**
     * Invoke the consumer for every (key2, value) pair stored under the given outer key. Does nothing if key1 is not mapped.
     * @param key1 The key1 part
     * @param consumer the consumer to invoke per entry
     */
    void forEach(Key1 key1, BiConsumer<Key2, Value> consumer) {
        ConcurrentHashMap<Key2, Value> map = mapping.get(key1);
        if (map != null) {
            map.forEach(consumer);
        }
    }

    /**
     * Invoke the consumer for every (key2, value) pair across all outer keys.
     * @param consumer the consumer to invoke per entry
     */
    void forEach(BiConsumer<Key2, Value> consumer) {
        for (ConcurrentHashMap<Key2, Value> map : mapping.values()) {
            map.forEach(consumer);
        }
    }
}

x-pack/plugin/blob-cache/src/main/java/org/elasticsearch/blobcache/shared/SharedBlobCacheService.java

Lines changed: 18 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1644,9 +1644,7 @@ void touch() {
16441644
}
16451645
}
16461646

1647-
// only put/remove through computeXXX to ensure we do not lose the data.
1648-
private final ConcurrentHashMap<ShardId, ConcurrentHashMap<RegionKey<KeyType>, LFUCacheEntry>> keyMapping =
1649-
new ConcurrentHashMap<>();
1647+
private final KeyMapping<ShardId, RegionKey<KeyType>, LFUCacheEntry> keyMapping = new KeyMapping<>();
16501648
private final LFUCacheEntry[] freqs;
16511649
private final int maxFreq;
16521650
private final DecayAndNewEpochTask decayAndNewEpochTask;
@@ -1666,7 +1664,7 @@ public void close() {
16661664
}
16671665

16681666
int getFreq(CacheFileRegion<KeyType> cacheFileRegion) {
1669-
return keyMapping.get(cacheFileRegion.regionKey.file().shardId()).get(cacheFileRegion.regionKey).freq;
1667+
return keyMapping.get(cacheFileRegion.regionKey.file().shardId(), cacheFileRegion.regionKey).freq;
16701668
}
16711669

16721670
@Override
@@ -1675,11 +1673,11 @@ public LFUCacheEntry get(KeyType cacheKey, long fileLength, int region) {
16751673
final long now = epoch.get();
16761674
// try to just get from the map on the fast-path to save instantiating the capturing lambda needed on the slow path
16771675
// if we did not find an entry
1678-
var perShardMapping = keyMapping.computeIfAbsent(cacheKey.shardId(), key -> new ConcurrentHashMap<>());
1679-
var entry = perShardMapping.get(regionKey);
1676+
var entry = keyMapping.get(cacheKey.shardId(), regionKey);
16801677
if (entry == null) {
16811678
final int effectiveRegionSize = computeCacheFileRegionSize(fileLength, region);
1682-
entry = perShardMapping.computeIfAbsent(
1679+
entry = keyMapping.computeIfAbsent(
1680+
cacheKey.shardId(),
16831681
regionKey,
16841682
key -> new LFUCacheEntry(new CacheFileRegion<KeyType>(SharedBlobCacheService.this, key, effectiveRegionSize), now)
16851683
);
@@ -1705,12 +1703,10 @@ public LFUCacheEntry get(KeyType cacheKey, long fileLength, int region) {
17051703
@Override
17061704
public int forceEvict(Predicate<KeyType> cacheKeyPredicate) {
17071705
final List<LFUCacheEntry> matchingEntries = new ArrayList<>();
1708-
keyMapping.forEach((shard, value) -> {
1709-
value.forEach((key, entry) -> {
1710-
if (cacheKeyPredicate.test(key.file)) {
1711-
matchingEntries.add(entry);
1712-
}
1713-
});
1706+
keyMapping.forEach((key, value) -> {
1707+
if (cacheKeyPredicate.test(key.file)) {
1708+
matchingEntries.add(value);
1709+
}
17141710
});
17151711
var evictedCount = 0;
17161712
var nonZeroFrequencyEvictedCount = 0;
@@ -1721,9 +1717,7 @@ public int forceEvict(Predicate<KeyType> cacheKeyPredicate) {
17211717
boolean evicted = entry.chunk.forceEvict();
17221718
if (evicted && entry.chunk.volatileIO() != null) {
17231719
unlink(entry);
1724-
// todo: can this be null? Should not, need to assert.
1725-
ShardId shard = entry.chunk.regionKey.file.shardId();
1726-
removeKeyMappingForEntry(entry, shard);
1720+
keyMapping.remove(entry.chunk.regionKey.file.shardId(), entry.chunk.regionKey, entry);
17271721
evictedCount++;
17281722
if (frequency > 0) {
17291723
nonZeroFrequencyEvictedCount++;
@@ -1737,26 +1731,7 @@ public int forceEvict(Predicate<KeyType> cacheKeyPredicate) {
17371731
}
17381732

17391733
private boolean removeKeyMappingForEntry(LFUCacheEntry entry) {
1740-
return removeKeyMappingForEntry(entry, entry.chunk.regionKey.file().shardId());
1741-
}
1742-
1743-
private boolean removeKeyMappingForEntry(LFUCacheEntry entry, ShardId shard) {
1744-
ConcurrentHashMap<RegionKey<KeyType>, LFUCacheEntry> map = keyMapping.get(shard);
1745-
if (map != null) {
1746-
boolean removed = map.remove(entry.chunk.regionKey, entry);
1747-
if (map.isEmpty()) {
1748-
keyMapping.computeIfPresent(shard, (shard1, entries) -> {
1749-
if (entries.isEmpty()) {
1750-
return null;
1751-
} else {
1752-
return entries;
1753-
}
1754-
});
1755-
}
1756-
return removed;
1757-
} else {
1758-
return false;
1759-
}
1734+
return keyMapping.remove(entry.chunk.regionKey.file().shardId(), entry.chunk.regionKey, entry);
17601735
}
17611736

17621737
@Override
@@ -1782,45 +1757,29 @@ public void onFailure(Exception e) {
17821757
@Override
17831758
public int forceEvict(ShardId shard, Predicate<KeyType> cacheKeyPredicate) {
17841759
final List<LFUCacheEntry> matchingEntries = new ArrayList<>();
1785-
ConcurrentHashMap<RegionKey<KeyType>, LFUCacheEntry> entries = keyMapping.get(shard);
1786-
if (entries != null) {
1787-
entries.forEach((key, entry) -> {
1788-
if (cacheKeyPredicate.test(key.file)) {
1789-
matchingEntries.add(entry);
1790-
}
1791-
});
1792-
}
1760+
keyMapping.forEach(shard, (key, entry) -> {
1761+
if (cacheKeyPredicate.test(key.file)) {
1762+
matchingEntries.add(entry);
1763+
}
1764+
});
17931765

17941766
var evictedCount = 0;
17951767
var nonZeroFrequencyEvictedCount = 0;
17961768
if (matchingEntries.isEmpty() == false) {
1797-
// todo: can this be null? Should not, need to assert.
1798-
ConcurrentHashMap<RegionKey<KeyType>, LFUCacheEntry> map = keyMapping.get(shard);
17991769
synchronized (SharedBlobCacheService.this) {
18001770
for (LFUCacheEntry entry : matchingEntries) {
18011771
int frequency = entry.freq;
18021772
boolean evicted = entry.chunk.forceEvict();
18031773
if (evicted && entry.chunk.volatileIO() != null) {
18041774
unlink(entry);
1805-
if (map != null) {
1806-
map.remove(entry.chunk.regionKey, entry);
1807-
}
1775+
keyMapping.remove(shard, entry.chunk.regionKey, entry);
18081776
evictedCount++;
18091777
if (frequency > 0) {
18101778
nonZeroFrequencyEvictedCount++;
18111779
}
18121780
}
18131781
}
18141782
}
1815-
if (map != null && map.isEmpty()) {
1816-
keyMapping.computeIfPresent(shard, (shard1, entries1) -> {
1817-
if (entries1.isEmpty()) {
1818-
return null;
1819-
} else {
1820-
return entries1;
1821-
}
1822-
});
1823-
}
18241783
}
18251784
blobCacheMetrics.getEvictedCountNonZeroFrequency().incrementBy(nonZeroFrequencyEvictedCount);
18261785
return evictedCount;
@@ -1829,8 +1788,7 @@ public int forceEvict(ShardId shard, Predicate<KeyType> cacheKeyPredicate) {
18291788
private LFUCacheEntry initChunk(LFUCacheEntry entry) {
18301789
assert Thread.holdsLock(entry.chunk);
18311790
RegionKey<KeyType> regionKey = entry.chunk.regionKey;
1832-
ConcurrentHashMap<RegionKey<KeyType>, LFUCacheEntry> perShardMapping = keyMapping.get(regionKey.file().shardId());
1833-
if (perShardMapping == null || perShardMapping.get(regionKey) != entry) {
1791+
if (keyMapping.get(regionKey.file().shardId(), regionKey) != entry) {
18341792
throwAlreadyClosed("no free region found (contender)");
18351793
}
18361794
// new item
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.blobcache.shared;
9+
10+
import org.elasticsearch.test.ESTestCase;
11+
12+
import java.util.HashSet;
13+
import java.util.List;
14+
import java.util.Set;
15+
import java.util.stream.IntStream;
16+
17+
public class KeyMappingTests extends ESTestCase {
18+
19+
public void testBasics() {
20+
final String k1 = randomAlphanumericOfLength(10);
21+
final String k2 = randomAlphanumericOfLength(10);
22+
final String value = randomAlphanumericOfLength(10);
23+
KeyMapping<String, String, String> mapping = new KeyMapping<>();
24+
assertNull(mapping.get(k1, k2));
25+
26+
assertEquals(value, mapping.computeIfAbsent(k1, k2, (kx) -> value));
27+
assertEquals(value, mapping.get(k1, k2));
28+
29+
mapping.computeIfAbsent(k1, k2, (kx) -> { throw new AssertionError(); });
30+
31+
assertEquals(value, mapping.get(k1, k2));
32+
33+
final String k12 = randomValueOtherThan(k1, () -> randomAlphanumericOfLength(10));
34+
mapping.computeIfAbsent(k12, k2, (kx) -> randomAlphanumericOfLength(10));
35+
36+
assertEquals(value, mapping.get(k1, k2));
37+
38+
assertEquals(Set.of(k1, k12), mapping.key1s());
39+
40+
Set<String> values = new HashSet<>();
41+
mapping.forEach(k1, (ak2, result) -> { assertTrue(values.add(result)); });
42+
assertEquals(Set.of(value), values);
43+
44+
assertTrue(mapping.remove(k1, k2, value));
45+
46+
assertEquals(Set.of(k12), mapping.key1s());
47+
48+
assertNull(mapping.get(k1, k2));
49+
assertNotNull(mapping.get(k12, k2));
50+
51+
assertFalse(mapping.remove(k1, k2, value));
52+
}
53+
54+
public void testMultiThreaded() {
55+
final String k1 = randomAlphanumericOfLength(10);
56+
KeyMapping<String, String, Integer> mapping = new KeyMapping<>();
57+
58+
List<Thread> threads = IntStream.range(0, 10).mapToObj(i -> new Thread(() -> {
59+
final String k2 = Integer.toString(i);
60+
logger.info(k2);
61+
62+
for (int j = 0; j < 1000; ++j) {
63+
Integer finalJ = j;
64+
assertNull(mapping.get(k1, k2));
65+
mapping.computeIfAbsent(k1, k2, (kx) -> finalJ);
66+
assertEquals(finalJ, mapping.get(k1, k2));
67+
assertTrue(mapping.remove(k1, k2, finalJ));
68+
if ((j & 1) == 0) {
69+
assertFalse(mapping.remove(k1, k2, finalJ));
70+
}
71+
72+
}
73+
assertNull(mapping.get(k1, k2));
74+
}, "test-thread-" + i)).toList();
75+
76+
threads.forEach(Thread::start);
77+
threads.forEach(t -> {
78+
try {
79+
t.join(10000);
80+
} catch (InterruptedException e) {
81+
throw new RuntimeException(e);
82+
}
83+
});
84+
85+
assertEquals(Set.of(), mapping.key1s());
86+
}
87+
}

0 commit comments

Comments
 (0)