Skip to content

Commit 9cd0c26

Browse files
authored
feat: Store & Cache messages when peer is choked (#181)
2 parents 9b6475a + 8ac653b commit 9cd0c26

File tree

2 files changed

+133
-17
lines changed

2 files changed

+133
-17
lines changed

crates/libtortillas/src/peer/actor.rs

Lines changed: 104 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use std::{
2-
collections::HashMap,
2+
collections::{HashMap, VecDeque},
33
sync::{Arc, atomic::AtomicU8},
44
time::Instant,
55
};
@@ -26,6 +26,8 @@ use crate::{
2626
torrent::{TorrentActor, TorrentMessage, TorrentRequest, TorrentResponse},
2727
};
2828

29+
const MAX_PENDING_MESSAGES: usize = 8;
30+
2931
const PEER_KEEPALIVE_TIMEOUT: u64 = 10;
3032
const PEER_DISCONNECT_TIMEOUT: u64 = 20;
3133

@@ -39,6 +41,7 @@ pub(crate) struct PeerActor {
3941
supervisor: ActorRef<TorrentActor>,
4042

4143
pending_block_requests: Arc<DashSet<(usize, usize, usize)>>,
44+
pending_message_requests: VecDeque<PeerMessages>,
4245
}
4346

4447
impl PeerActor {
@@ -168,7 +171,7 @@ impl PeerActor {
168171
None,
169172
);
170173

171-
if let Err(e) = self.stream.send(message).await {
174+
if let Err(e) = self.send_message(message).await {
172175
trace!(error = %e, piece, "Failed to send metadata request");
173176
}
174177
} else {
@@ -223,6 +226,87 @@ impl PeerActor {
223226

224227
self.peer.set_am_interested(has_interesting_pieces);
225228
}
229+
230+
/// Sends all queued messages to the peer. This sends synchronously, and will
231+
/// not return until each message has been sent. This is because most of
232+
/// the time we want the messages to be sent in their original order.
233+
#[instrument(skip(self), fields(peer_addr = %self.stream, peer_id = %self.peer.id.unwrap()))]
234+
async fn flush_queue(&mut self) {
235+
if self.pending_message_requests.is_empty() {
236+
return;
237+
}
238+
239+
let queued_messages = self.pending_message_requests.len();
240+
241+
while let Some(msg) = self.pending_message_requests.pop_back() {
242+
self
243+
.stream
244+
.send(msg)
245+
.await
246+
.expect("Failed to send message to peer");
247+
}
248+
249+
trace!(amount = queued_messages, "Flushed queued messages to peer");
250+
}
251+
252+
/// Flushes/resends all pending block requests to the peer.
253+
#[instrument(skip(self), fields(peer_addr = %self.stream, peer_id = %self.peer.id.unwrap()))]
254+
async fn flush_block_requests(&mut self) {
255+
if self.pending_block_requests.is_empty() {
256+
return;
257+
}
258+
259+
let queued_block_requests = self.pending_block_requests.len();
260+
let mut completed = 0usize;
261+
262+
for request in self.pending_block_requests.iter() {
263+
let (index, begin, length) = *request;
264+
if self
265+
.stream
266+
.send(PeerMessages::Request(
267+
index as u32,
268+
begin as u32,
269+
length as u32,
270+
))
271+
.await
272+
.is_ok()
273+
{
274+
completed += 1;
275+
}
276+
}
277+
trace!(
278+
amount = queued_block_requests,
279+
amount_succussful = completed,
280+
"Flushed queued block requests to peer"
281+
);
282+
}
283+
284+
/// Send a message to the peer. Checks if the peer is choked, and if so,
285+
/// queues the message in [`self.pending_message_requests`]. This function
286+
/// will NOT queue request messages since they have their own queue of
287+
/// sorts.
288+
///
289+
/// Unless you're doing something like a `KeepAlive` message or a piece
290+
/// request, you should use this function over [`Self::stream.send`].
291+
#[instrument(skip(self), fields(peer_addr = %self.stream, peer_id = %self.peer.id.unwrap()))]
292+
async fn send_message(&mut self, msg: PeerMessages) -> Result<(), PeerActorError> {
293+
if self.peer.am_choked() {
294+
// Only push the message if it's not a request
295+
if matches!(msg, PeerMessages::Request(..)) {
296+
return Ok(());
297+
}
298+
if self.pending_message_requests.len() >= MAX_PENDING_MESSAGES {
299+
self.pending_message_requests.pop_back();
300+
}
301+
302+
self.pending_message_requests.push_front(msg);
303+
trace!("Peer is choked, queueing message");
304+
305+
return Ok(());
306+
}
307+
308+
self.stream.send(msg).await
309+
}
226310
}
227311

228312
impl Actor for PeerActor {
@@ -251,6 +335,7 @@ impl Actor for PeerActor {
251335
stream,
252336
supervisor,
253337
pending_block_requests: Arc::new(DashSet::new()),
338+
pending_message_requests: VecDeque::with_capacity(MAX_PENDING_MESSAGES),
254339
})
255340
}
256341

@@ -336,6 +421,10 @@ impl Message<PeerMessages> for PeerActor {
336421
PeerMessages::Unchoke => {
337422
self.peer.update_last_optimistic_unchoke();
338423
self.peer.set_am_choked(false);
424+
425+
// Send all pending messages
426+
self.flush_queue().await;
427+
self.flush_block_requests().await;
339428
trace!("Peer unchoked us");
340429
}
341430
PeerMessages::Interested => {
@@ -450,15 +539,17 @@ impl Message<PeerTell> for PeerActor {
450539
return;
451540
}
452541

453-
self
454-
.stream
455-
.send(PeerMessages::Request(
456-
index as u32,
457-
begin as u32,
458-
length as u32,
459-
))
460-
.await
461-
.expect("Failed to send piece request");
542+
if !self.peer.am_choked() {
543+
self
544+
.stream
545+
.send(PeerMessages::Request(
546+
index as u32,
547+
begin as u32,
548+
length as u32,
549+
))
550+
.await
551+
.expect("Failed to send piece request");
552+
}
462553
self.pending_block_requests.insert((index, begin, length));
463554
trace!(piece_index = index, "Sent piece request to peer");
464555
}
@@ -485,14 +576,13 @@ impl Message<PeerTell> for PeerActor {
485576
}
486577
PeerTell::HaveInfoDict(bitfield) => {
487578
self
488-
.stream
489-
.send(PeerMessages::Bitfield(bitfield))
579+
.send_message(PeerMessages::Bitfield(bitfield))
490580
.await
491581
.expect("Failed to send bitfield");
492582
trace!("Sent bitfield to peer");
493583
}
494584
PeerTell::Have(piece) => {
495-
if let Err(e) = self.stream.send(PeerMessages::Have(piece as u32)).await {
585+
if let Err(e) = self.send_message(PeerMessages::Have(piece as u32)).await {
496586
trace!(piece_num = piece, error = %e, "Failed to send Have message to peer");
497587
}
498588
}

crates/libtortillas/src/protocol/messages.rs

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use core::hash;
12
use std::{
23
collections::HashMap,
34
fmt::Display,
@@ -23,7 +24,7 @@ use crate::{
2324
peer::{MAGIC_STRING, PeerId},
2425
};
2526

26-
#[derive(Debug, Clone, PartialEq, Eq)]
27+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
2728
#[repr(u8)]
2829
/// Represents messages exchanged between peers in the BitTorrent protocol.
2930
///
@@ -366,7 +367,8 @@ impl PeerMessages {
366367
PartialEq,
367368
Eq,
368369
Deserialize_repr,
369-
TryFromPrimitive
370+
TryFromPrimitive,
371+
Hash
370372
)]
371373
#[repr(u8)]
372374
pub enum ExtendedMessageType {
@@ -456,6 +458,30 @@ pub struct ExtendedMessage {
456458
pub total_size: Option<usize>,
457459
}
458460

461+
impl hash::Hash for ExtendedMessage {
462+
fn hash<H: hash::Hasher>(&self, state: &mut H) {
463+
if let Some(extensions) = &self.supported_extensions {
464+
let mut pairs: Vec<_> = extensions.iter().collect();
465+
pairs.sort_by_key(|i| i.0);
466+
467+
pairs.hash(state);
468+
}
469+
470+
self.local_port.hash(state);
471+
self.version.hash(state);
472+
self.your_ip.hash(state);
473+
self.ipv6.hash(state);
474+
self.ipv4.hash(state);
475+
self.outstanding_requests.hash(state);
476+
self.metadata_size.hash(state);
477+
if let Some(msg_type) = &self.msg_type {
478+
msg_type.hash(state);
479+
}
480+
self.piece.hash(state);
481+
self.total_size.hash(state);
482+
}
483+
}
484+
459485
impl ExtendedMessage {
460486
pub fn new() -> Self {
461487
Self::default()
@@ -502,7 +528,7 @@ impl ExtendedMessage {
502528
}
503529

504530
/// BitTorrent Handshake message structure
505-
#[derive(Debug, Clone, PartialEq, Eq)]
531+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
506532
pub struct Handshake {
507533
/// Protocol identifier (typically "BitTorrent protocol")
508534
pub protocol: Bytes,

0 commit comments

Comments
 (0)