@@ -5,9 +5,10 @@ use axum::{
55 Json , Router ,
66 body:: Body ,
77 extract:: {
8- WebSocketUpgrade ,
8+ ConnectInfo , WebSocketUpgrade ,
99 ws:: { Message , WebSocket } ,
1010 } ,
11+ http:: StatusCode ,
1112 middleware:: Next ,
1213 response:: IntoResponse ,
1314 routing:: { any, get} ,
@@ -19,6 +20,7 @@ use futures_util::{SinkExt, StreamExt as _};
1920use gitbutler_project:: ProjectId ;
2021use serde:: { Deserialize , Serialize } ;
2122use serde_json:: json;
23+ use std:: net:: SocketAddr ;
2224use tokio:: sync:: Mutex ;
2325use tower_http:: cors:: { Any , CorsLayer } ;
2426
@@ -52,6 +54,21 @@ struct AppState {
5254 app_settings : AppSettingsWithDiskSync ,
5355}
5456
57+ /// Middleware to ensure all connections are from localhost only
58+ async fn localhost_only_middleware (
59+ ConnectInfo ( addr) : ConnectInfo < SocketAddr > ,
60+ req : axum:: extract:: Request < Body > ,
61+ next : Next ,
62+ ) -> Result < impl IntoResponse , StatusCode > {
63+ // Check if the connection is from localhost (127.0.0.1 or ::1)
64+ if addr. ip ( ) . is_loopback ( ) {
65+ Ok ( next. run ( req) . await )
66+ } else {
67+ tracing:: warn!( "Rejected non-localhost connection from: {}" , addr) ;
68+ Err ( StatusCode :: FORBIDDEN )
69+ }
70+ }
71+
5572pub async fn run ( ) {
5673 let cors = CorsLayer :: new ( )
5774 . allow_methods ( Any )
@@ -104,6 +121,8 @@ pub async fn run() {
104121 tokio:: task:: spawn ( next. run ( req) ) . await . unwrap ( )
105122 } ,
106123 ) )
124+ // Middleware to ensure only localhost connections are accepted
125+ . layer ( axum:: middleware:: from_fn ( localhost_only_middleware) )
107126 . layer ( cors)
108127 . with_state ( state) ;
109128
@@ -112,7 +131,12 @@ pub async fn run() {
112131 let url = format ! ( "{host}:{port}" ) ;
113132 let listener = tokio:: net:: TcpListener :: bind ( & url) . await . unwrap ( ) ;
114133 println ! ( "Running at {url}" ) ;
115- axum:: serve ( listener, app) . await . unwrap ( ) ;
134+ axum:: serve (
135+ listener,
136+ app. into_make_service_with_connect_info :: < SocketAddr > ( ) ,
137+ )
138+ . await
139+ . unwrap ( ) ;
116140}
117141
118142async fn post_handle_json_command (
0 commit comments