Skip to content

Commit d9c72ce

Browse files
CopilotByron
andcommitted
Add localhost-only middleware to but-server
Co-authored-by: Byron <[email protected]>
1 parent b185180 commit d9c72ce

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

crates/but-server/src/lib.rs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@ use axum::{
55
Json, Router,
66
body::Body,
77
extract::{
8-
WebSocketUpgrade,
8+
ConnectInfo, WebSocketUpgrade,
99
ws::{Message, WebSocket},
1010
},
11+
http::StatusCode,
1112
middleware::Next,
1213
response::IntoResponse,
1314
routing::{any, get},
@@ -19,6 +20,7 @@ use futures_util::{SinkExt, StreamExt as _};
1920
use gitbutler_project::ProjectId;
2021
use serde::{Deserialize, Serialize};
2122
use serde_json::json;
23+
use std::net::SocketAddr;
2224
use tokio::sync::Mutex;
2325
use tower_http::cors::{Any, CorsLayer};
2426

@@ -52,6 +54,21 @@ struct AppState {
5254
app_settings: AppSettingsWithDiskSync,
5355
}
5456

57+
/// Middleware to ensure all connections are from localhost only
58+
async fn localhost_only_middleware(
59+
ConnectInfo(addr): ConnectInfo<SocketAddr>,
60+
req: axum::extract::Request<Body>,
61+
next: Next,
62+
) -> Result<impl IntoResponse, StatusCode> {
63+
// Check if the connection is from localhost (127.0.0.1 or ::1)
64+
if addr.ip().is_loopback() {
65+
Ok(next.run(req).await)
66+
} else {
67+
tracing::warn!("Rejected non-localhost connection from: {}", addr);
68+
Err(StatusCode::FORBIDDEN)
69+
}
70+
}
71+
5572
pub async fn run() {
5673
let cors = CorsLayer::new()
5774
.allow_methods(Any)
@@ -104,6 +121,8 @@ pub async fn run() {
104121
tokio::task::spawn(next.run(req)).await.unwrap()
105122
},
106123
))
124+
// Middleware to ensure only localhost connections are accepted
125+
.layer(axum::middleware::from_fn(localhost_only_middleware))
107126
.layer(cors)
108127
.with_state(state);
109128

@@ -112,7 +131,12 @@ pub async fn run() {
112131
let url = format!("{host}:{port}");
113132
let listener = tokio::net::TcpListener::bind(&url).await.unwrap();
114133
println!("Running at {url}");
115-
axum::serve(listener, app).await.unwrap();
134+
axum::serve(
135+
listener,
136+
app.into_make_service_with_connect_info::<SocketAddr>(),
137+
)
138+
.await
139+
.unwrap();
116140
}
117141

118142
async fn post_handle_json_command(

0 commit comments

Comments
 (0)