Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions examples/protocol_id.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//! Example demonstrating network isolation using protocol IDs.
//!
//! This example shows how to create isolated DHT networks that don't interfere
//! with each other using the protocol ID.
//!
//! Run with:
//! ```bash
//! cargo run --example protocol_id
//! ```

use mainline::{Dht, DEFAULT_STAGING_PROTOCOL_ID};

fn main() -> Result<(), std::io::Error> {
println!("Protocol ID Example\n");

// Example 1: Default behavior - participates in main BitTorrent DHT
println!("1. Creating DHT node without protocol ID (default behavior):");
println!(" This node will communicate with the main BitTorrent DHT network");
let _default_dht = Dht::client()?;
println!(" ✓ Created\n");

// Example 2: Using a custom protocol ID for an isolated network
println!("2. Creating DHT node with custom protocol ID:");
println!(" Protocol ID: /myapp/mainline/1.0.0");
println!(" This node will only communicate with other nodes using the same protocol ID");
let _custom_dht = Dht::builder()
.protocol_id("/myapp/mainline/1.0.0")
.build()?;
println!(" ✓ Created\n");

// Example 3: Using the default staging protocol ID constant
println!("4. Creating DHT node using DEFAULT_STAGING_PROTOCOL_ID:");
println!(" Protocol ID: {}", DEFAULT_STAGING_PROTOCOL_ID);
println!(" Useful for creating isolated test networks");
let _staging_dht = Dht::builder()
.protocol_id(DEFAULT_STAGING_PROTOCOL_ID)
.build()?;
println!(" ✓ Created\n");

println!("Network Isolation Rules:");
println!("• Nodes with the same protocol ID can communicate");
println!("• Nodes with different protocol IDs ignore each other's messages");
println!("• Nodes without a protocol ID (None) accept all messages (backward compatible)");
println!("• Nodes with a protocol ID ONLY accept messages with matching protocol ID");

Ok(())
}
43 changes: 43 additions & 0 deletions src/common/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ pub(crate) struct Message {
/// For bep0043. When set true on a request, indicates that the requester can't reply to requests and that responders should not add requester to their routing tables.
/// Should only be set on requests - undefined behavior when set on a response.
pub read_only: bool,

/// Optional protocol ID for network isolation
pub protocol_id: Option<Box<[u8]>>,
}

#[derive(Debug, PartialEq, Clone)]
Expand Down Expand Up @@ -220,6 +223,7 @@ impl Message {
.requester_ip
.map(|sockaddr| sockaddr_to_bytes(&sockaddr)),
read_only: if self.read_only { Some(1) } else { Some(0) },
protocol_id: self.protocol_id,
variant: match self.message_type {
MessageType::Request(RequestSpecific {
requester_id,
Expand Down Expand Up @@ -412,6 +416,7 @@ impl Message {
} else {
false
},
protocol_id: msg.protocol_id,
message_type: match msg.variant {
internal::DHTMessageVariant::Request(req_variant) => {
MessageType::Request(match req_variant {
Expand Down Expand Up @@ -664,6 +669,10 @@ impl Message {
_ => None,
}
}

pub fn protocol_id(&self) -> Option<&[u8]> {
self.protocol_id.as_deref()
}
}

fn bytes_to_sockaddr<T: AsRef<[u8]>>(bytes: T) -> Result<SocketAddrV4, DecodeMessageError> {
Expand Down Expand Up @@ -781,6 +790,7 @@ mod tests {
version: None,
requester_ip: None,
read_only: false,
protocol_id: None,
message_type: MessageType::Request(RequestSpecific {
requester_id: Id::random(),
request_type: RequestTypeSpecific::Ping,
Expand All @@ -801,6 +811,7 @@ mod tests {
version: Some([0xde, 0xad, 0, 1]),
requester_ip: Some("99.100.101.102:1030".parse().unwrap()),
read_only: false,
protocol_id: None,
message_type: MessageType::Response(ResponseSpecific::Ping(PingResponseArguments {
responder_id: Id::random(),
})),
Expand All @@ -820,6 +831,7 @@ mod tests {
version: Some([0x62, 0x61, 0x72, 0x66]),
requester_ip: None,
read_only: false,
protocol_id: None,
message_type: MessageType::Request(RequestSpecific {
requester_id: Id::random(),
request_type: RequestTypeSpecific::FindNode(FindNodeRequestArguments {
Expand All @@ -842,6 +854,7 @@ mod tests {
version: Some([0x62, 0x61, 0x72, 0x66]),
requester_ip: None,
read_only: true,
protocol_id: None,
message_type: MessageType::Request(RequestSpecific {
requester_id: Id::random(),
request_type: RequestTypeSpecific::FindNode(FindNodeRequestArguments {
Expand All @@ -864,6 +877,7 @@ mod tests {
version: Some([1, 2, 3, 4]),
requester_ip: Some("50.51.52.53:5455".parse().unwrap()),
read_only: false,
protocol_id: None,
message_type: MessageType::Response(ResponseSpecific::FindNode(
FindNodeResponseArguments {
responder_id: Id::random(),
Expand Down Expand Up @@ -897,6 +911,7 @@ mod tests {
version: Some([72, 73, 0, 1]),
requester_ip: None,
read_only: false,
protocol_id: None,
message_type: MessageType::Request(RequestSpecific {
requester_id: Id::random(),
request_type: RequestTypeSpecific::GetPeers(GetPeersRequestArguments {
Expand All @@ -919,6 +934,7 @@ mod tests {
version: Some([1, 2, 3, 4]),
requester_ip: Some("50.51.52.53:5455".parse().unwrap()),
read_only: true,
protocol_id: None,
message_type: MessageType::Response(ResponseSpecific::NoValues(
NoValuesResponseArguments {
responder_id: Id::random(),
Expand Down Expand Up @@ -959,6 +975,7 @@ mod tests {
version: Some([1, 2, 3, 4]),
requester_ip: Some("50.51.52.53:5455".parse().unwrap()),
read_only: false,
protocol_id: None,
message_type: MessageType::Response(ResponseSpecific::GetPeers(
GetPeersResponseArguments {
responder_id: Id::random(),
Expand All @@ -983,6 +1000,7 @@ mod tests {
read_only: None,
transaction_id: [1, 2, 3, 4],
version: None,
protocol_id: None,
variant: internal::DHTMessageVariant::Response(
internal::DHTResponseSpecific::NoValues {
arguments: internal::DHTNoValuesResponseArguments {
Expand All @@ -1007,6 +1025,7 @@ mod tests {
version: Some([72, 73, 0, 1]),
requester_ip: None,
read_only: false,
protocol_id: None,
message_type: MessageType::Request(RequestSpecific {
requester_id: Id::random(),
request_type: RequestTypeSpecific::GetValue(GetValueRequestArguments {
Expand All @@ -1031,6 +1050,7 @@ mod tests {
version: Some([1, 2, 3, 4]),
requester_ip: Some("50.51.52.53:5455".parse().unwrap()),
read_only: false,
protocol_id: None,
message_type: MessageType::Response(ResponseSpecific::GetImmutable(
GetImmutableResponseArguments {
responder_id: Id::random(),
Expand All @@ -1055,6 +1075,7 @@ mod tests {
version: Some([1, 2, 3, 4]),
requester_ip: Some("50.51.52.53:5455".parse().unwrap()),
read_only: false,
protocol_id: None,
message_type: MessageType::Request(RequestSpecific {
requester_id: Id::random(),
request_type: RequestTypeSpecific::Put(PutRequest {
Expand Down Expand Up @@ -1083,6 +1104,7 @@ mod tests {
version: Some([1, 2, 3, 4]),
requester_ip: Some("50.51.52.53:5455".parse().unwrap()),
read_only: false,
protocol_id: None,
message_type: MessageType::Request(RequestSpecific {
requester_id: Id::random(),
request_type: RequestTypeSpecific::Put(PutRequest {
Expand All @@ -1106,4 +1128,25 @@ mod tests {
let parsed_msg = Message::from_serde_message(parsed_serde_msg).unwrap();
assert_eq!(parsed_msg, original_msg);
}

#[test]
fn test_protocol_id_request() {
// Old nodes (without protocol_id field) should be able to parse messages with protocol_id
// because bencode ignores unknown fields
let protocol_id = b"/pubky/mainline/1.0.0";
let msg_with_protocol = Message {
transaction_id: 258,
version: Some([0x62, 0x61, 0x72, 0x66]),
requester_ip: None,
read_only: false,
protocol_id: Some(protocol_id.to_vec().into_boxed_slice()),
message_type: MessageType::Request(RequestSpecific {
requester_id: Id::random(),
request_type: RequestTypeSpecific::Ping,
}),
};
let bytes = msg_with_protocol.to_bytes().unwrap();
let parsed = Message::from_bytes(&bytes).unwrap();
assert_eq!(parsed.protocol_id(), Some(protocol_id.as_ref()));
}
}
5 changes: 5 additions & 0 deletions src/common/messages/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ pub struct DHTMessage {
#[serde(default)]
#[serde(rename = "ro")]
pub read_only: Option<i32>,

#[serde(default)]
#[serde(rename = "p", with = "serde_bytes")]
/// protocol ID for network isolation
pub protocol_id: Option<Box<[u8]>>,
}

impl DHTMessage {
Expand Down
141 changes: 141 additions & 0 deletions src/dht.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,20 @@ impl DhtBuilder {
self
}

/// Set the protocol ID for network isolation
///
/// When set, this node will only communicate with other nodes using the same protocol ID.
/// Messages from nodes with different or no protocol IDs will be rejected.
///
/// Format: "/prefix/mainline/version" (e.g., "/pubky/mainline/1.0.0")
///
/// When None (default), accepts all messages for backward compatiblility.
pub fn protocol_id(&mut self, protocol_id: impl Into<String>) -> &mut Self {
self.0.protocol_id = Some(protocol_id.into());

self
}

/// Create a Dht node.
pub fn build(&self) -> Result<Dht, std::io::Error> {
Dht::new(self.0.clone())
Expand Down Expand Up @@ -1103,4 +1117,131 @@ mod test {
.iter()
.all(|n| n.to_bootstrap().len() == size - 1));
}

#[test]
fn protocol_id_isolation_different_networks_cannot_communicate() {
let testnet = Testnet::new(5).unwrap();

// Create node A with protocol ID "/network_a/mainline/1.0.0"
let node_a = Dht::builder()
.protocol_id("/network_a/mainline/1.0.0")
.bootstrap(&testnet.bootstrap)
.build()
.unwrap();

// Create node B with different protocol ID "/network_b/mainline/1.0.0"
let node_b = Dht::builder()
.protocol_id("/network_b/mainline/1.0.0")
.bootstrap(&testnet.bootstrap)
.build()
.unwrap();

// Wait for nodes to attempt bootstrapping
std::thread::sleep(std::time::Duration::from_millis(500));

// Node A puts immutable data
let value = b"Hello from Network A";
let target = node_a.put_immutable(value).unwrap();

// Node B (on different network) should NOT be able to get the data
// Because they have different protocol IDs, B's requests are ignored by A's network
let result = node_b.get_immutable(target);
assert!(
result.is_none(),
"Node B should not be able to retrieve data from Network A"
);
}

#[test]
fn protocol_id_isolation_same_network_can_communicate() {
let mut nodes: Vec<Dht> = vec![];
let mut bootstrap = vec![];

// Create first node with protocol ID
let node = Dht::builder()
.protocol_id("/test_network/mainline/1.0.0")
.server_mode()
.no_bootstrap()
.build()
.unwrap();

let info = node.info();
bootstrap.push(format!("127.0.0.1:{}", info.local_addr().port()));
nodes.push(node);

// Create more nodes with SAME protocol ID
for _ in 1..5 {
let node = Dht::builder()
.protocol_id("/test_network/mainline/1.0.0")
.server_mode()
.bootstrap(&bootstrap)
.build()
.unwrap();
nodes.push(node);
}

for node in &nodes {
node.bootstrapped();
}

// Create two client nodes with the SAME protocol ID
let node_a = Dht::builder()
.protocol_id("/test_network/mainline/1.0.0")
.bootstrap(&bootstrap)
.build()
.unwrap();
let node_b = Dht::builder()
.protocol_id("/test_network/mainline/1.0.0")
.bootstrap(&bootstrap)
.build()
.unwrap();

// Node A puts immutable data
let value = b"Hello from same network";
let target = node_a.put_immutable(value).unwrap();

// Node B (on same network) SHOULD be able to get the data
let result = node_b.get_immutable(target);
assert!(
result.is_some(),
"Node B should be able to retrieve data from Node A (same network)"
);
assert_eq!(result.unwrap().as_ref(), value);
}

#[test]
fn protocol_id_node_rejects_messages_without_protocol_id() {
let testnet = Testnet::new(5).unwrap();

// Create node with protocol ID
let node_with_protocol = Dht::builder()
.protocol_id("/custom_network/mainline/1.0.0")
.bootstrap(&testnet.bootstrap)
.build()
.unwrap();

// Create node WITHOUT protocol ID (default)
let node_without_protocol = Dht::builder()
.bootstrap(&testnet.bootstrap)
.build()
.unwrap();

// Wait for bootstrapping
std::thread::sleep(std::time::Duration::from_millis(500));

// Node without protocol ID puts data
let value = b"Hello from default network";
let target = node_without_protocol.put_immutable(value).unwrap();

// Node with protocol ID will NOT be able to read from default network nodes
// because nodes with protocol IDs reject messages without protocol IDs
let result = node_with_protocol.get_immutable(target);

// The node with protocol ID cannot retrieve data from nodes without protocol IDs
// because it rejects their responses (no protocol ID = rejected)
assert!(
result.is_none(),
"Node with protocol ID should not retrieve data from nodes without protocol ID"
);
}
}
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ pub use rpc::{

pub use ed25519_dalek::SigningKey;

/// Default protocol ID for staging isolated network
pub const DEFAULT_STAGING_PROTOCOL_ID: &str = "/pubky_staging/mainline/1.0.0";

pub mod errors {
//! Exported errors
#[cfg(feature = "node")]
Expand Down
Loading
Loading