Skip to content

Commit 0437aad

Browse files
committed
Cancel safety
1 parent 3873fce commit 0437aad

File tree

1 file changed

+56
-16
lines changed

1 file changed

+56
-16
lines changed

crates/core/src/client/client_connection.rs

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,24 @@ impl DurableOffsetSupply for RelationalDB {
148148
pub struct ClientConnectionReceiver {
149149
confirmed_reads: bool,
150150
channel: MeteredReceiver<ClientUpdate>,
151+
current: Option<ClientUpdate>,
151152
offset_supply: Box<dyn DurableOffsetSupply>,
152153
}
153154

154155
impl ClientConnectionReceiver {
156+
fn new(
157+
confirmed_reads: bool,
158+
channel: MeteredReceiver<ClientUpdate>,
159+
offset_supply: impl DurableOffsetSupply + 'static,
160+
) -> Self {
161+
Self {
162+
confirmed_reads,
163+
channel,
164+
current: None,
165+
offset_supply: Box::new(offset_supply),
166+
}
167+
}
168+
155169
/// Receive the next message from this channel.
156170
///
157171
/// If this method returns `None`, the channel is closed and no more messages
@@ -173,17 +187,36 @@ impl ClientConnectionReceiver {
173187
/// If the database is shut down while waiting for the durable offset,
174188
/// `None` is returned. In this case, no more messages can ever be received
175189
/// from the channel.
190+
///
191+
/// # Cancel safety
192+
///
193+
/// This method is cancel safe, as long as `self` is not dropped.
194+
///
195+
/// If `recv` is used in a [`tokio::select!`] statement, it may get
196+
/// cancelled while waiting for the durable offset to catch up. At this
197+
/// point, it has already received a value from the underlying channel.
198+
/// This value is stored internally, so calling `recv` again will not lose
199+
/// data.
176200
//
177201
// TODO: Can we make a cancel-safe `recv_many` with confirmed reads semantics?
178202
pub async fn recv(&mut self) -> Option<SerializableMessage> {
179-
let ClientUpdate { tx_offset, message } = self.channel.recv().await?;
203+
let ClientUpdate { tx_offset, message } = match self.current.take() {
204+
None => self.channel.recv().await?,
205+
Some(update) => update,
206+
};
180207
if !self.confirmed_reads {
181208
return Some(message);
182209
}
183210

184211
if let Some(tx_offset) = tx_offset {
185212
match self.offset_supply.durable_offset() {
186213
Ok(Some(mut durable)) => {
214+
// Store the current update in case we get cancelled while
215+
// waiting for the durable offset.
216+
self.current = Some(ClientUpdate {
217+
tx_offset: Some(tx_offset),
218+
message,
219+
});
187220
trace!("waiting for offset {tx_offset} to become durable");
188221
durable
189222
.wait_for(tx_offset)
@@ -192,16 +225,16 @@ impl ClientConnectionReceiver {
192225
warn!("database went away while waiting for durable offset");
193226
})
194227
.ok()?;
228+
self.current.take().map(|update| update.message)
195229
}
196230
// Database shut down or crashed.
197-
Err(NoSuchModule) => return None,
231+
Err(NoSuchModule) => None,
198232
// In-memory database.
199-
Ok(None) => return Some(message),
233+
Ok(None) => Some(message),
200234
}
235+
} else {
236+
Some(message)
201237
}
202-
203-
trace!("returning durable message");
204-
Some(message)
205238
}
206239

207240
/// Close the receiver without dropping it.
@@ -290,11 +323,7 @@ impl ClientConnectionSender {
290323
Err(_) => tokio::runtime::Runtime::new().unwrap().spawn(async {}).abort_handle(),
291324
};
292325

293-
let receiver = ClientConnectionReceiver {
294-
confirmed_reads: config.confirmed_reads,
295-
channel: MeteredReceiver::new(rx),
296-
offset_supply: Box::new(offset_supply),
297-
};
326+
let receiver = ClientConnectionReceiver::new(config.confirmed_reads, MeteredReceiver::new(rx), offset_supply);
298327
let cancelled = AtomicBool::new(false);
299328
let sender = Self {
300329
id,
@@ -666,11 +695,11 @@ impl ClientConnection {
666695
.abort_handle();
667696

668697
let metrics = ClientConnectionMetrics::new(database_identity, config.protocol);
669-
let receiver = ClientConnectionReceiver {
670-
confirmed_reads: config.confirmed_reads,
671-
channel: MeteredReceiver::with_gauge(sendrx, metrics.sendtx_queue_size.clone()),
672-
offset_supply: Box::new(module_rx.clone()),
673-
};
698+
let receiver = ClientConnectionReceiver::new(
699+
config.confirmed_reads,
700+
MeteredReceiver::with_gauge(sendrx, metrics.sendtx_queue_size.clone()),
701+
module_rx.clone(),
702+
);
674703

675704
let sender = Arc::new(ClientConnectionSender {
676705
id,
@@ -1110,4 +1139,15 @@ mod tests {
11101139
assert_received_update(receiver.recv()).await;
11111140
}
11121141
}
1142+
1143+
#[tokio::test]
1144+
async fn client_connection_receiver_cancel_safety() {
1145+
let offset = FakeDurableOffset::new();
1146+
let (sender, mut receiver) = client_with_confirmed_reads(offset.clone());
1147+
1148+
sender.send_message(Some(3), empty_tx_update()).unwrap();
1149+
assert_pending(&mut pin!(receiver.recv())).await;
1150+
offset.mark_durable_at(3);
1151+
assert_received_update(receiver.recv()).await;
1152+
}
11131153
}

0 commit comments

Comments
 (0)