Skip to content

Commit b53b6af

Browse files
committed
codex: split remote websocket client from server auth (#14853)
1 parent 4f28b64 commit b53b6af

File tree

3 files changed

+361
-72
lines changed

3 files changed

+361
-72
lines changed

codex-rs/app-server-client/src/lib.rs

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -866,8 +866,11 @@ mod tests {
866866
use tokio::net::TcpListener;
867867
use tokio::time::Duration;
868868
use tokio::time::timeout;
869-
use tokio_tungstenite::accept_async;
869+
use tokio_tungstenite::accept_hdr_async;
870870
use tokio_tungstenite::tungstenite::Message;
871+
use tokio_tungstenite::tungstenite::handshake::server::Request as WebSocketRequest;
872+
use tokio_tungstenite::tungstenite::handshake::server::Response as WebSocketResponse;
873+
use tokio_tungstenite::tungstenite::http::header::AUTHORIZATION;
871874

872875
async fn build_test_config() -> Config {
873876
match ConfigBuilder::default().build().await {
@@ -906,6 +909,19 @@ mod tests {
906909
}
907910

908911
async fn start_test_remote_server<F, Fut>(handler: F) -> String
912+
where
913+
F: FnOnce(tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>) -> Fut
914+
+ Send
915+
+ 'static,
916+
Fut: std::future::Future<Output = ()> + Send + 'static,
917+
{
918+
start_test_remote_server_with_auth(None, handler).await
919+
}
920+
921+
async fn start_test_remote_server_with_auth<F, Fut>(
922+
expected_auth_token: Option<String>,
923+
handler: F,
924+
) -> String
909925
where
910926
F: FnOnce(tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>) -> Fut
911927
+ Send
@@ -918,9 +934,23 @@ mod tests {
918934
let addr = listener.local_addr().expect("listener address");
919935
tokio::spawn(async move {
920936
let (stream, _) = listener.accept().await.expect("accept should succeed");
921-
let websocket = accept_async(stream)
922-
.await
923-
.expect("websocket upgrade should succeed");
937+
let websocket = accept_hdr_async(
938+
stream,
939+
move |request: &WebSocketRequest, response: WebSocketResponse| {
940+
let provided_auth_token = request
941+
.headers()
942+
.get(AUTHORIZATION)
943+
.and_then(|value| value.to_str().ok())
944+
.map(str::to_owned);
945+
let expected_auth_token = expected_auth_token
946+
.as_ref()
947+
.map(|token| format!("Bearer {token}"));
948+
assert_eq!(provided_auth_token, expected_auth_token);
949+
Ok(response)
950+
},
951+
)
952+
.await
953+
.expect("websocket upgrade should succeed");
924954
handler(websocket).await;
925955
});
926956
format!("ws://{addr}")
@@ -988,6 +1018,7 @@ mod tests {
9881018
fn test_remote_connect_args(websocket_url: String) -> RemoteAppServerConnectArgs {
9891019
RemoteAppServerConnectArgs {
9901020
websocket_url,
1021+
auth_token: None,
9911022
client_name: "codex-app-server-client-test".to_string(),
9921023
client_version: "0.0.0-test".to_string(),
9931024
experimental_api: true,
@@ -1114,6 +1145,7 @@ mod tests {
11141145
}),
11151146
)
11161147
.await;
1148+
websocket.close(None).await.expect("close should succeed");
11171149
})
11181150
.await;
11191151
let client = RemoteAppServerClient::connect(test_remote_connect_args(websocket_url))
@@ -1134,6 +1166,58 @@ mod tests {
11341166
client.shutdown().await.expect("shutdown should complete");
11351167
}
11361168

1169+
async fn remote_connect_includes_auth_header_when_configured() {
1170+
let auth_token = "remote-bearer-token".to_string();
1171+
let websocket_url = start_test_remote_server_with_auth(
1172+
Some(auth_token.clone()),
1173+
|mut websocket| async move {
1174+
expect_remote_initialize(&mut websocket).await;
1175+
websocket.close(None).await.expect("close should succeed");
1176+
},
1177+
)
1178+
.await;
1179+
let client = RemoteAppServerClient::connect(RemoteAppServerConnectArgs {
1180+
auth_token: Some(auth_token),
1181+
..test_remote_connect_args(websocket_url)
1182+
})
1183+
.await
1184+
.expect("remote client should connect");
1185+
1186+
client.shutdown().await.expect("shutdown should complete");
1187+
}
1188+
1189+
#[tokio::test]
1190+
async fn remote_connect_rejects_non_loopback_ws_when_auth_configured() {
1191+
let result = RemoteAppServerClient::connect(RemoteAppServerConnectArgs {
1192+
websocket_url: "ws://example.com:4500".to_string(),
1193+
auth_token: Some("remote-bearer-token".to_string()),
1194+
..test_remote_connect_args("ws://127.0.0.1:1".to_string())
1195+
})
1196+
.await;
1197+
let err = match result {
1198+
Ok(_) => panic!("non-loopback ws should be rejected before connect"),
1199+
Err(err) => err,
1200+
};
1201+
assert_eq!(err.kind(), ErrorKind::InvalidInput);
1202+
assert!(
1203+
err.to_string()
1204+
.contains("remote auth tokens require `wss://` or loopback `ws://` URLs")
1205+
);
1206+
}
1207+
1208+
#[test]
1209+
fn remote_auth_token_transport_policy_allows_wss_and_loopback_ws() {
1210+
assert!(crate::remote::websocket_url_supports_auth_token(
1211+
&url::Url::parse("wss://example.com:443").expect("wss URL should parse")
1212+
));
1213+
assert!(crate::remote::websocket_url_supports_auth_token(
1214+
&url::Url::parse("ws://127.0.0.1:4500").expect("loopback ws URL should parse")
1215+
));
1216+
assert!(!crate::remote::websocket_url_supports_auth_token(
1217+
&url::Url::parse("ws://example.com:4500").expect("non-loopback ws URL should parse")
1218+
));
1219+
}
1220+
11371221
#[tokio::test]
11381222
async fn remote_duplicate_request_id_keeps_original_waiter() {
11391223
let (first_request_seen_tx, first_request_seen_rx) = tokio::sync::oneshot::channel();

codex-rs/app-server-client/src/remote.rs

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ use tokio_tungstenite::MaybeTlsStream;
4747
use tokio_tungstenite::WebSocketStream;
4848
use tokio_tungstenite::connect_async;
4949
use tokio_tungstenite::tungstenite::Message;
50+
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
51+
use tokio_tungstenite::tungstenite::http::HeaderValue;
52+
use tokio_tungstenite::tungstenite::http::header::AUTHORIZATION;
5053
use tracing::warn;
5154
use url::Url;
5255

@@ -56,6 +59,7 @@ const INITIALIZE_TIMEOUT: Duration = Duration::from_secs(10);
5659
#[derive(Debug, Clone)]
5760
pub struct RemoteAppServerConnectArgs {
5861
pub websocket_url: String,
62+
pub auth_token: Option<String>,
5963
pub client_name: String,
6064
pub client_version: String,
6165
pub experimental_api: bool,
@@ -85,6 +89,16 @@ impl RemoteAppServerConnectArgs {
8589
}
8690
}
8791

92+
pub(crate) fn websocket_url_supports_auth_token(url: &Url) -> bool {
93+
match (url.scheme(), url.host()) {
94+
("wss", Some(_)) => true,
95+
("ws", Some(url::Host::Domain(domain))) => domain.eq_ignore_ascii_case("localhost"),
96+
("ws", Some(url::Host::Ipv4(addr))) => addr.is_loopback(),
97+
("ws", Some(url::Host::Ipv6(addr))) => addr.is_loopback(),
98+
_ => false,
99+
}
100+
}
101+
88102
enum RemoteClientCommand {
89103
Request {
90104
request: Box<ClientRequest>,
@@ -131,7 +145,31 @@ impl RemoteAppServerClient {
131145
format!("invalid websocket URL `{websocket_url}`: {err}"),
132146
)
133147
})?;
134-
let stream = timeout(CONNECT_TIMEOUT, connect_async(url.as_str()))
148+
if args.auth_token.is_some() && !websocket_url_supports_auth_token(&url) {
149+
return Err(IoError::new(
150+
ErrorKind::InvalidInput,
151+
format!(
152+
"remote auth tokens require `wss://` or loopback `ws://` URLs; got `{websocket_url}`"
153+
),
154+
));
155+
}
156+
let mut request = url.as_str().into_client_request().map_err(|err| {
157+
IoError::new(
158+
ErrorKind::InvalidInput,
159+
format!("invalid websocket URL `{websocket_url}`: {err}"),
160+
)
161+
})?;
162+
if let Some(auth_token) = args.auth_token.as_deref() {
163+
let header_value =
164+
HeaderValue::from_str(&format!("Bearer {auth_token}")).map_err(|err| {
165+
IoError::new(
166+
ErrorKind::InvalidInput,
167+
format!("invalid remote authorization header value: {err}"),
168+
)
169+
})?;
170+
request.headers_mut().insert(AUTHORIZATION, header_value);
171+
}
172+
let stream = timeout(CONNECT_TIMEOUT, connect_async(request))
135173
.await
136174
.map_err(|_| {
137175
IoError::new(

0 commit comments

Comments
 (0)