
Commit 3e3b3c3

fix: data race when doing a snapshot in cache mode (#5336)
The issue: during snapshotting, concurrent fibers could fetch the same delayed entry, which led to data races and deadlocks. The fix: pull the entry onto the fiber's stack before preempting, ensuring each fiber handles its own entry. Fixes #4965

Signed-off-by: Roman Gershman <[email protected]>
1 parent 2ad1099 commit 3e3b3c3
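
To make the failure mode concrete, below is a minimal sketch of the pattern fixed in PushSerialized. It is not the actual Dragonfly code: DelayedEntry is reduced to two strings, the blocking serializer call is a stand-in, and fb2 fiber scheduling is elided into comments.

#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct DelayedEntry {
  std::string key;
  std::string value;
};

std::vector<DelayedEntry> delayed_entries_;

// Stand-in for serializer_->SaveEntry(...). In the real code this call may
// block, which lets another fiber enter the same drain loop.
void SaveEntryBlocking(const DelayedEntry& entry) {
  std::cout << "serialized " << entry.key << '\n';
}

void PushSerialized() {
  do {
    // Buggy version: `auto& entry = delayed_entries_.back();` kept a reference
    // into the shared vector across the blocking call below. A second fiber
    // resuming here would grab the same back() entry, and its pop_back() would
    // invalidate the first fiber's reference.
    //
    // Fixed version: move the entry onto this fiber's stack and pop it before
    // any preemption point, so each fiber owns exactly one entry.
    DelayedEntry entry = std::move(delayed_entries_.back());
    delayed_entries_.pop_back();
    SaveEntryBlocking(entry);  // safe to block: `entry` is fiber-local now
  } while (!delayed_entries_.empty());
}

int main() {
  delayed_entries_ = {{"k1", "v1"}, {"k2", "v2"}};
  PushSerialized();
}

The ordering is the whole fix: both the move and the pop_back happen before the blocking save, so by the time the fiber can be preempted the shared vector no longer references the entry being serialized.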

File tree

5 files changed: +49 -14 lines changed


src/server/snapshot.cc

Lines changed: 5 additions & 3 deletions
@@ -104,6 +104,7 @@ void SliceSnapshot::Start(bool stream_journal, SnapshotFlush allow_flush) {
 }
 
 void SliceSnapshot::StartIncremental(LSN start_lsn) {
+  VLOG(1) << "StartIncremental: " << start_lsn;
   serializer_ = std::make_unique<RdbSerializer>(compression_mode_);
 
   snapshot_fb_ = fb2::Fiber("incremental_snapshot",
@@ -319,7 +320,6 @@ size_t SliceSnapshot::FlushSerialized(SerializerBase::FlushState flush_state) {
   uint64_t running_cycles = ThisFiber::GetRunningTimeCycles();
 
   fb2::NoOpLock lk;
-
   // We create a critical section here that ensures that records are pushed in sequential order.
   // As a result, it is not possible for two fiber producers to push concurrently.
   // If A.id = 5, and then B.id = 6, and both are blocked here, it means that last_pushed_id_ < 4.
@@ -356,7 +356,10 @@ bool SliceSnapshot::PushSerialized(bool force) {
   // Async bucket serialization might have accumulated some delayed values.
   // Because we can finally block in this function, we'll await and serialize them
   do {
-    auto& entry = delayed_entries_.back();
+    // We may call PushSerialized from multiple fibers concurrently, so we need to
+    // ensure that we are not serializing the same entry concurrently.
+    DelayedEntry entry = std::move(delayed_entries_.back());
+    delayed_entries_.pop_back();
 
     // TODO: https://github.com/dragonflydb/dragonfly/issues/4654
     // there are a few problems with how we serialize external values.
@@ -367,7 +370,6 @@ bool SliceSnapshot::PushSerialized(bool force) {
 
     // TODO: to introduce RdbSerializer::SaveString that can accept a string value directly.
     serializer_->SaveEntry(entry.key, pv, entry.expire, entry.mc_flags, entry.dbid);
-    delayed_entries_.pop_back();
  } while (!delayed_entries_.empty());
 
  // blocking point.

src/server/snapshot.h

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ class SliceSnapshot : public journal::JournalConsumerInterface {
 
   // Initialize snapshot, start bucket iteration fiber, register listeners.
   // In journal streaming mode it needs to be stopped by either Stop or Cancel.
-  enum class SnapshotFlush { kAllow, kDisallow };
+  enum class SnapshotFlush : uint8_t { kAllow, kDisallow };
 
   void Start(bool stream_journal, SnapshotFlush allow_flush = SnapshotFlush::kDisallow);
 
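
The header change is a small cleanup riding along with the fix: a scoped enum's underlying type defaults to int, so annotating SnapshotFlush with : uint8_t pins it to a single byte and makes the representation explicit. A standalone compile-time sketch of the guarantee (not the Dragonfly header itself):

#include <cstdint>
#include <type_traits>

// Without the annotation, the underlying type would default to int.
enum class SnapshotFlush : std::uint8_t { kAllow, kDisallow };

static_assert(sizeof(SnapshotFlush) == 1);
static_assert(std::is_same_v<std::underlying_type_t<SnapshotFlush>, std::uint8_t>);

int main() {}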

tests/dragonfly/replication_test.py

Lines changed: 0 additions & 8 deletions
@@ -26,14 +26,6 @@
 M_NOT_EPOLL = [pytest.mark.exclude_epoll]
 
 
-async def wait_for_replicas_state(*clients, state="online", node_role="slave", timeout=0.05):
-    """Wait until all clients (replicas) reach passed state"""
-    while len(clients) > 0:
-        await asyncio.sleep(timeout)
-        roles = await asyncio.gather(*(c.role() for c in clients))
-        clients = [c for c, role in zip(clients, roles) if role[0] != node_role or role[3] != state]
-
-
 """
 Test full replication pipeline. Test full sync with streaming changes and stable state streaming.
 """

tests/dragonfly/tiering_test.py

Lines changed: 35 additions & 2 deletions
@@ -1,14 +1,15 @@
 import async_timeout
 import asyncio
 import itertools
+import logging
 import pytest
 import random
 import redis.asyncio as aioredis
 
 from . import dfly_args
 from .seeder import DebugPopulateSeeder
-from .utility import info_tick_timer
-
+from .utility import info_tick_timer, wait_for_replicas_state
+from .instance import DflyInstanceFactory
 
 BASIC_ARGS = {"port": 6379, "proactor_threads": 4, "tiered_prefix": "/tmp/tiered/backing"}
 
@@ -87,3 +88,35 @@ async def run(sub_ops):
         res = await p.execute()
 
     assert res == [10 * k for k in key_range]
+
+
+@pytest.mark.exclude_epoll
+@pytest.mark.opt_only
+@dfly_args(
+    {
+        "proactor_threads": 2,
+        "tiered_prefix": "/tmp/tiered/backing_master",
+        "maxmemory": "4G",
+        "cache_mode": True,
+        "tiered_offload_threshold": "0.2",
+        "tiered_storage_write_depth": 100,
+    }
+)
+async def test_full_sync(async_client: aioredis.Redis, df_factory: DflyInstanceFactory):
+    replica = df_factory.create(
+        proactor_threads=2,
+        cache_mode=True,
+        maxmemory="4G",
+        tiered_prefix="/tmp/tiered/backing_replica",
+        tiered_offload_threshold="0.2",
+        tiered_storage_write_depth=1000,
+    )
+    replica.start()
+    replica_client = replica.client()
+    await async_client.execute_command("debug", "populate", "3000000", "key", "2000")
+    await replica_client.replicaof(
+        "localhost", async_client.connection_pool.connection_kwargs["port"]
+    )
+    logging.info("Waiting for replica to sync")
+    async with async_timeout.timeout(120):
+        await wait_for_replicas_state(replica_client)

tests/dragonfly/utility.py

Lines changed: 8 additions & 0 deletions
@@ -848,3 +848,11 @@ def extract_int_after_prefix(prefix, line):
     match = re.search(prefix + "(\\d+)", line)
     assert match
     return int(match.group(1))
+
+
+async def wait_for_replicas_state(*clients, state="online", node_role="slave", timeout=0.05):
+    """Wait until all clients (replicas) reach passed state"""
+    while len(clients) > 0:
+        await asyncio.sleep(timeout)
+        roles = await asyncio.gather(*(c.role() for c in clients))
+        clients = [c for c, role in zip(clients, roles) if role[0] != node_role or role[3] != state]
