Skip to content

Commit 7d2fc51

Browse files
committed
feat: optinally limit access to authorized_keys file
1 parent ed5b995 commit 7d2fc51

File tree

1 file changed

+108
-28
lines changed

1 file changed

+108
-28
lines changed

src/main.rs

Lines changed: 108 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
11
//! Command line arguments.
22
use clap::{Parser, Subcommand};
33
use dumbpipe::NodeTicket;
4-
use iroh::{endpoint::Connecting, Endpoint, NodeAddr, SecretKey, Watcher};
5-
use n0_snafu::{Result, ResultExt};
4+
use iroh::{
5+
endpoint::{Connecting, Connection},
6+
Endpoint, NodeAddr, NodeId, SecretKey, Watcher,
7+
};
8+
use n0_snafu::{format_err, Result, ResultExt};
69
use std::{
710
io,
811
net::{SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
12+
path::{Path, PathBuf},
913
str::FromStr,
14+
sync::Arc,
1015
};
1116
use tokio::{
1217
io::{AsyncRead, AsyncWrite, AsyncWriteExt},
1318
select,
1419
};
1520
use tokio_util::sync::CancellationToken;
21+
use tracing::info;
1622

1723
/// Create a dumb pipe between two machines, using an iroh magicsocket.
1824
///
@@ -122,19 +128,40 @@ fn parse_alpn(alpn: &str) -> Result<Vec<u8>> {
122128
})
123129
}
124130

131+
/// Arguments shared among commands accepting connections.
132+
#[derive(Parser, Debug)]
133+
pub struct CommonAcceptArgs {
134+
/// Optionally limit access to node ids listed in this file.
135+
#[clap(short = 'a', long)]
136+
pub authorized_keys: Option<PathBuf>,
137+
}
138+
139+
impl CommonAcceptArgs {
140+
async fn authorized_keys(&self) -> Result<Option<AuthorizedKeys>> {
141+
if let Some(ref path) = self.authorized_keys {
142+
Ok(Some(AuthorizedKeys::load(path).await?))
143+
} else {
144+
Ok(None)
145+
}
146+
}
147+
}
148+
125149
#[derive(Parser, Debug)]
126150
pub struct ListenArgs {
127151
#[clap(flatten)]
128152
pub common: CommonArgs,
153+
#[clap(flatten)]
154+
pub accept: CommonAcceptArgs,
129155
}
130156

131157
#[derive(Parser, Debug)]
132158
pub struct ListenTcpArgs {
133159
#[clap(long)]
134160
pub host: String,
135-
136161
#[clap(flatten)]
137162
pub common: CommonArgs,
163+
#[clap(flatten)]
164+
pub accept: CommonAcceptArgs,
138165
}
139166

140167
#[derive(Parser, Debug)]
@@ -267,6 +294,7 @@ async fn forward_bidi(
267294

268295
async fn listen_stdio(args: ListenArgs) -> Result<()> {
269296
let secret_key = get_or_create_secret()?;
297+
let authorized_keys = args.accept.authorized_keys().await?;
270298
let mut builder = Endpoint::builder()
271299
.alpns(vec![args.common.alpn()?])
272300
.secret_key(secret_key);
@@ -277,6 +305,7 @@ async fn listen_stdio(args: ListenArgs) -> Result<()> {
277305
builder = builder.bind_addr_v6(addr);
278306
}
279307
let endpoint = builder.bind().await?;
308+
eprintln!("endpoint bound with node id {}", endpoint.node_id());
280309
// wait for the endpoint to figure out its address before making a ticket
281310
endpoint.home_relay().initialized().await?;
282311
let node = endpoint.node_addr().initialized().await?;
@@ -306,7 +335,12 @@ async fn listen_stdio(args: ListenArgs) -> Result<()> {
306335
}
307336
};
308337
let remote_node_id = &connection.remote_node_id()?;
309-
tracing::info!("got connection from {}", remote_node_id);
338+
info!("got connection from {}", remote_node_id);
339+
if let Some(ref authorized_keys) = authorized_keys {
340+
if authorized_keys.authorize(&connection).is_err() {
341+
continue;
342+
}
343+
}
310344
let (s, mut r) = match connection.accept_bi().await {
311345
Ok(x) => x,
312346
Err(cause) => {
@@ -315,14 +349,14 @@ async fn listen_stdio(args: ListenArgs) -> Result<()> {
315349
continue;
316350
}
317351
};
318-
tracing::info!("accepted bidi stream from {}", remote_node_id);
352+
info!("accepted bidi stream from {}", remote_node_id);
319353
if !args.common.is_custom_alpn() {
320354
// read the handshake and verify it
321355
let mut buf = [0u8; dumbpipe::HANDSHAKE.len()];
322356
r.read_exact(&mut buf).await.e()?;
323357
snafu::ensure_whatever!(buf == dumbpipe::HANDSHAKE, "invalid handshake");
324358
}
325-
tracing::info!("forwarding stdin/stdout to {}", remote_node_id);
359+
info!("forwarding stdin/stdout to {}", remote_node_id);
326360
forward_bidi(tokio::io::stdin(), tokio::io::stdout(), r, s).await?;
327361
// stop accepting connections after the first successful one
328362
break;
@@ -341,23 +375,24 @@ async fn connect_stdio(args: ConnectArgs) -> Result<()> {
341375
builder = builder.bind_addr_v6(addr);
342376
}
343377
let endpoint = builder.bind().await?;
378+
eprintln!("endpoint bound with node id {}", endpoint.node_id());
344379
let addr = args.ticket.node_addr();
345380
let remote_node_id = addr.node_id;
346381
// connect to the node, try only once
347382
let connection = endpoint.connect(addr.clone(), &args.common.alpn()?).await?;
348-
tracing::info!("connected to {}", remote_node_id);
383+
info!("connected to {}", remote_node_id);
349384
// open a bidi stream, try only once
350-
let (mut s, r) = connection.open_bi().await.e()?;
351-
tracing::info!("opened bidi stream to {}", remote_node_id);
385+
let (mut send, recv) = connection.open_bi().await.e()?;
386+
info!("opened bidi stream to {}", remote_node_id);
352387
// send the handshake unless we are using a custom alpn
353388
// when using a custom alpn, evertyhing is up to the user
354389
if !args.common.is_custom_alpn() {
355390
// the connecting side must write first. we don't know if there will be something
356391
// on stdin, so just write a handshake.
357-
s.write_all(&dumbpipe::HANDSHAKE).await.e()?;
392+
send.write_all(&dumbpipe::HANDSHAKE).await.e()?;
358393
}
359-
tracing::info!("forwarding stdin/stdout to {}", remote_node_id);
360-
forward_bidi(tokio::io::stdin(), tokio::io::stdout(), r, s).await?;
394+
info!("forwarding stdin/stdout to {}", remote_node_id);
395+
forward_bidi(tokio::io::stdin(), tokio::io::stdout(), recv, send).await?;
361396
tokio::io::stdout().flush().await.e()?;
362397
Ok(())
363398
}
@@ -377,14 +412,12 @@ async fn connect_tcp(args: ConnectTcpArgs) -> Result<()> {
377412
builder = builder.bind_addr_v6(addr);
378413
}
379414
let endpoint = builder.bind().await.context("unable to bind magicsock")?;
380-
tracing::info!("tcp listening on {:?}", addrs);
381-
let tcp_listener = match tokio::net::TcpListener::bind(addrs.as_slice()).await {
382-
Ok(tcp_listener) => tcp_listener,
383-
Err(cause) => {
384-
tracing::error!("error binding tcp socket to {:?}: {}", addrs, cause);
385-
return Ok(());
386-
}
387-
};
415+
eprintln!("endpoint bound with node id {}", endpoint.node_id());
416+
let tcp_listener = tokio::net::TcpListener::bind(addrs.as_slice())
417+
.await
418+
.with_context(|| format!("error binding tcp socket to {:?}", addrs.as_slice()))?;
419+
info!("tcp listening on {:?}", addrs.as_slice());
420+
388421
async fn handle_tcp_accept(
389422
next: io::Result<(tokio::net::TcpStream, SocketAddr)>,
390423
addr: NodeAddr,
@@ -394,7 +427,7 @@ async fn connect_tcp(args: ConnectTcpArgs) -> Result<()> {
394427
) -> Result<()> {
395428
let (tcp_stream, tcp_addr) = next.context("error accepting tcp connection")?;
396429
let (tcp_recv, tcp_send) = tcp_stream.into_split();
397-
tracing::info!("got tcp connection from {}", tcp_addr);
430+
info!("got tcp connection from {}", tcp_addr);
398431
let remote_node_id = addr.node_id;
399432
let connection = endpoint
400433
.connect(addr, alpn)
@@ -412,8 +445,9 @@ async fn connect_tcp(args: ConnectTcpArgs) -> Result<()> {
412445
magic_send.write_all(&dumbpipe::HANDSHAKE).await.e()?;
413446
}
414447
forward_bidi(tcp_recv, tcp_send, magic_recv, magic_send).await?;
415-
Ok::<_, n0_snafu::Error>(())
448+
Ok(())
416449
}
450+
417451
let addr = args.ticket.node_addr();
418452
loop {
419453
// also wait for ctrl-c here so we can use it before accepting a connection
@@ -433,7 +467,7 @@ async fn connect_tcp(args: ConnectTcpArgs) -> Result<()> {
433467
// log error at warn level
434468
//
435469
// we should know about it, but it's not fatal
436-
tracing::warn!("error handling connection: {}", cause);
470+
tracing::warn!("error handling connection: {:#}", cause);
437471
}
438472
});
439473
}
@@ -447,6 +481,7 @@ async fn listen_tcp(args: ListenTcpArgs) -> Result<()> {
447481
Err(e) => snafu::whatever!("invalid host string {}: {}", args.host, e),
448482
};
449483
let secret_key = get_or_create_secret()?;
484+
let authorized_keys = args.accept.authorized_keys().await?;
450485
let mut builder = Endpoint::builder()
451486
.alpns(vec![args.common.alpn()?])
452487
.secret_key(secret_key);
@@ -457,13 +492,15 @@ async fn listen_tcp(args: ListenTcpArgs) -> Result<()> {
457492
builder = builder.bind_addr_v6(addr);
458493
}
459494
let endpoint = builder.bind().await?;
495+
eprintln!("endpoint bound with node id {}", endpoint.node_id());
460496
// wait for the endpoint to figure out its address before making a ticket
461497
endpoint.home_relay().initialized().await?;
462498
let node_addr = endpoint.node_addr().initialized().await?;
463499
let mut short = node_addr.clone();
464500
let ticket = NodeTicket::new(node_addr);
465501
short.direct_addresses.clear();
466502
let short = NodeTicket::new(short);
503+
println!("ticket {short:?}");
467504

468505
// print the ticket on stderr so it doesn't interfere with the data itself
469506
//
@@ -474,23 +511,27 @@ async fn listen_tcp(args: ListenTcpArgs) -> Result<()> {
474511
if args.common.verbose > 0 {
475512
eprintln!("or:\ndumbpipe connect-tcp {short}");
476513
}
477-
tracing::info!("node id is {}", ticket.node_addr().node_id);
478-
tracing::info!("derp url is {:?}", ticket.node_addr().relay_url);
514+
info!("node id is {}", ticket.node_addr().node_id);
515+
info!("derp url is {:?}", ticket.node_addr().relay_url);
479516

480517
// handle a new incoming connection on the magic endpoint
481518
async fn handle_magic_accept(
482519
connecting: Connecting,
483520
addrs: Vec<std::net::SocketAddr>,
484521
handshake: bool,
522+
authorized_keys: Option<AuthorizedKeys>,
485523
) -> Result<()> {
486524
let connection = connecting.await.context("error accepting connection")?;
487525
let remote_node_id = &connection.remote_node_id()?;
488-
tracing::info!("got connection from {}", remote_node_id);
526+
info!("got connection from {}", remote_node_id);
527+
if let Some(ref authorized_keys) = authorized_keys {
528+
authorized_keys.authorize(&connection)?;
529+
}
489530
let (s, mut r) = connection
490531
.accept_bi()
491532
.await
492533
.context("error accepting stream")?;
493-
tracing::info!("accepted bidi stream from {}", remote_node_id);
534+
info!("accepted bidi stream from {}", remote_node_id);
494535
if handshake {
495536
// read the handshake and verify it
496537
let mut buf = [0u8; dumbpipe::HANDSHAKE.len()];
@@ -521,8 +562,11 @@ async fn listen_tcp(args: ListenTcpArgs) -> Result<()> {
521562
};
522563
let addrs = addrs.clone();
523564
let handshake = !args.common.is_custom_alpn();
565+
let authorized_keys = authorized_keys.clone();
524566
tokio::spawn(async move {
525-
if let Err(cause) = handle_magic_accept(connecting, addrs, handshake).await {
567+
if let Err(cause) =
568+
handle_magic_accept(connecting, addrs, handshake, authorized_keys).await
569+
{
526570
// log error at warn level
527571
//
528572
// we should know about it, but it's not fatal
@@ -533,6 +577,42 @@ async fn listen_tcp(args: ListenTcpArgs) -> Result<()> {
533577
Ok(())
534578
}
535579

580+
#[derive(Debug, Clone)]
581+
struct AuthorizedKeys(Arc<Vec<NodeId>>);
582+
583+
impl AuthorizedKeys {
584+
async fn load(path: impl AsRef<Path>) -> Result<Self> {
585+
let path = path.as_ref();
586+
let keys: Result<Vec<NodeId>> = tokio::fs::read_to_string(path)
587+
.await
588+
.with_context(|| format!("failed to read authorized keys file at {}", path.display()))?
589+
.lines()
590+
.filter_map(|line| line.split_whitespace().next())
591+
.filter(|str| !str.starts_with('#'))
592+
.map(|str| {
593+
NodeId::from_str(str).with_context(|| {
594+
format!("failed to parse node id `{str}` from authorized keys file")
595+
})
596+
})
597+
.collect();
598+
Ok(Self(Arc::new(keys?)))
599+
}
600+
601+
fn authorize(&self, connection: &Connection) -> Result<()> {
602+
let remote = connection.remote_node_id()?;
603+
if !self.0.contains(&remote) {
604+
connection.close(403u32.into(), b"unauthorized");
605+
info!(
606+
remote = %remote.fmt_short(),
607+
"rejecting connection: unauthorized",
608+
);
609+
Err(format_err!("connection rejected: unauthorized"))
610+
} else {
611+
Ok(())
612+
}
613+
}
614+
}
615+
536616
#[tokio::main]
537617
async fn main() -> Result<()> {
538618
tracing_subscriber::fmt::init();

0 commit comments

Comments
 (0)