Skip to content

Commit 6796e59

Browse files
committed
codex: split remote websocket client from server auth (#14853)
1 parent e5de136 commit 6796e59

File tree

3 files changed

+365
-77
lines changed

3 files changed

+365
-77
lines changed

codex-rs/app-server-client/src/lib.rs

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -865,8 +865,11 @@ mod tests {
865865
use tokio::net::TcpListener;
866866
use tokio::time::Duration;
867867
use tokio::time::timeout;
868-
use tokio_tungstenite::accept_async;
868+
use tokio_tungstenite::accept_hdr_async;
869869
use tokio_tungstenite::tungstenite::Message;
870+
use tokio_tungstenite::tungstenite::handshake::server::Request as WebSocketRequest;
871+
use tokio_tungstenite::tungstenite::handshake::server::Response as WebSocketResponse;
872+
use tokio_tungstenite::tungstenite::http::header::AUTHORIZATION;
870873

871874
async fn build_test_config() -> Config {
872875
match ConfigBuilder::default().build().await {
@@ -905,6 +908,19 @@ mod tests {
905908
}
906909

907910
async fn start_test_remote_server<F, Fut>(handler: F) -> String
911+
where
912+
F: FnOnce(tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>) -> Fut
913+
+ Send
914+
+ 'static,
915+
Fut: std::future::Future<Output = ()> + Send + 'static,
916+
{
917+
start_test_remote_server_with_auth(None, handler).await
918+
}
919+
920+
async fn start_test_remote_server_with_auth<F, Fut>(
921+
expected_auth_token: Option<String>,
922+
handler: F,
923+
) -> String
908924
where
909925
F: FnOnce(tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>) -> Fut
910926
+ Send
@@ -917,9 +933,23 @@ mod tests {
917933
let addr = listener.local_addr().expect("listener address");
918934
tokio::spawn(async move {
919935
let (stream, _) = listener.accept().await.expect("accept should succeed");
920-
let websocket = accept_async(stream)
921-
.await
922-
.expect("websocket upgrade should succeed");
936+
let websocket = accept_hdr_async(
937+
stream,
938+
move |request: &WebSocketRequest, response: WebSocketResponse| {
939+
let provided_auth_token = request
940+
.headers()
941+
.get(AUTHORIZATION)
942+
.and_then(|value| value.to_str().ok())
943+
.map(str::to_owned);
944+
let expected_auth_token = expected_auth_token
945+
.as_ref()
946+
.map(|token| format!("Bearer {token}"));
947+
assert_eq!(provided_auth_token, expected_auth_token);
948+
Ok(response)
949+
},
950+
)
951+
.await
952+
.expect("websocket upgrade should succeed");
923953
handler(websocket).await;
924954
});
925955
format!("ws://{addr}")
@@ -987,6 +1017,7 @@ mod tests {
9871017
fn test_remote_connect_args(websocket_url: String) -> RemoteAppServerConnectArgs {
9881018
RemoteAppServerConnectArgs {
9891019
websocket_url,
1020+
auth_token: None,
9901021
client_name: "codex-app-server-client-test".to_string(),
9911022
client_version: "0.0.0-test".to_string(),
9921023
experimental_api: true,
@@ -1117,6 +1148,7 @@ mod tests {
11171148
}),
11181149
)
11191150
.await;
1151+
websocket.close(None).await.expect("close should succeed");
11201152
})
11211153
.await;
11221154
let client = RemoteAppServerClient::connect(test_remote_connect_args(websocket_url))
@@ -1137,6 +1169,58 @@ mod tests {
11371169
client.shutdown().await.expect("shutdown should complete");
11381170
}
11391171

1172+
async fn remote_connect_includes_auth_header_when_configured() {
1173+
let auth_token = "remote-bearer-token".to_string();
1174+
let websocket_url = start_test_remote_server_with_auth(
1175+
Some(auth_token.clone()),
1176+
|mut websocket| async move {
1177+
expect_remote_initialize(&mut websocket).await;
1178+
websocket.close(None).await.expect("close should succeed");
1179+
},
1180+
)
1181+
.await;
1182+
let client = RemoteAppServerClient::connect(RemoteAppServerConnectArgs {
1183+
auth_token: Some(auth_token),
1184+
..test_remote_connect_args(websocket_url)
1185+
})
1186+
.await
1187+
.expect("remote client should connect");
1188+
1189+
client.shutdown().await.expect("shutdown should complete");
1190+
}
1191+
1192+
#[tokio::test]
1193+
async fn remote_connect_rejects_non_loopback_ws_when_auth_configured() {
1194+
let result = RemoteAppServerClient::connect(RemoteAppServerConnectArgs {
1195+
websocket_url: "ws://example.com:4500".to_string(),
1196+
auth_token: Some("remote-bearer-token".to_string()),
1197+
..test_remote_connect_args("ws://127.0.0.1:1".to_string())
1198+
})
1199+
.await;
1200+
let err = match result {
1201+
Ok(_) => panic!("non-loopback ws should be rejected before connect"),
1202+
Err(err) => err,
1203+
};
1204+
assert_eq!(err.kind(), ErrorKind::InvalidInput);
1205+
assert!(
1206+
err.to_string()
1207+
.contains("remote auth tokens require `wss://` or loopback `ws://` URLs")
1208+
);
1209+
}
1210+
1211+
#[test]
1212+
fn remote_auth_token_transport_policy_allows_wss_and_loopback_ws() {
1213+
assert!(crate::remote::websocket_url_supports_auth_token(
1214+
&url::Url::parse("wss://example.com:443").expect("wss URL should parse")
1215+
));
1216+
assert!(crate::remote::websocket_url_supports_auth_token(
1217+
&url::Url::parse("ws://127.0.0.1:4500").expect("loopback ws URL should parse")
1218+
));
1219+
assert!(!crate::remote::websocket_url_supports_auth_token(
1220+
&url::Url::parse("ws://example.com:4500").expect("non-loopback ws URL should parse")
1221+
));
1222+
}
1223+
11401224
#[tokio::test]
11411225
async fn remote_duplicate_request_id_keeps_original_waiter() {
11421226
let (first_request_seen_tx, first_request_seen_rx) = tokio::sync::oneshot::channel();

codex-rs/app-server-client/src/remote.rs

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ use tokio_tungstenite::MaybeTlsStream;
4747
use tokio_tungstenite::WebSocketStream;
4848
use tokio_tungstenite::connect_async;
4949
use tokio_tungstenite::tungstenite::Message;
50+
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
51+
use tokio_tungstenite::tungstenite::http::HeaderValue;
52+
use tokio_tungstenite::tungstenite::http::header::AUTHORIZATION;
5053
use tracing::warn;
5154
use url::Url;
5255

@@ -56,6 +59,7 @@ const INITIALIZE_TIMEOUT: Duration = Duration::from_secs(10);
5659
#[derive(Debug, Clone)]
5760
pub struct RemoteAppServerConnectArgs {
5861
pub websocket_url: String,
62+
pub auth_token: Option<String>,
5963
pub client_name: String,
6064
pub client_version: String,
6165
pub experimental_api: bool,
@@ -85,6 +89,16 @@ impl RemoteAppServerConnectArgs {
8589
}
8690
}
8791

92+
pub(crate) fn websocket_url_supports_auth_token(url: &Url) -> bool {
93+
match (url.scheme(), url.host()) {
94+
("wss", Some(_)) => true,
95+
("ws", Some(url::Host::Domain(domain))) => domain.eq_ignore_ascii_case("localhost"),
96+
("ws", Some(url::Host::Ipv4(addr))) => addr.is_loopback(),
97+
("ws", Some(url::Host::Ipv6(addr))) => addr.is_loopback(),
98+
_ => false,
99+
}
100+
}
101+
88102
enum RemoteClientCommand {
89103
Request {
90104
request: Box<ClientRequest>,
@@ -131,7 +145,31 @@ impl RemoteAppServerClient {
131145
format!("invalid websocket URL `{websocket_url}`: {err}"),
132146
)
133147
})?;
134-
let stream = timeout(CONNECT_TIMEOUT, connect_async(url.as_str()))
148+
if args.auth_token.is_some() && !websocket_url_supports_auth_token(&url) {
149+
return Err(IoError::new(
150+
ErrorKind::InvalidInput,
151+
format!(
152+
"remote auth tokens require `wss://` or loopback `ws://` URLs; got `{websocket_url}`"
153+
),
154+
));
155+
}
156+
let mut request = url.as_str().into_client_request().map_err(|err| {
157+
IoError::new(
158+
ErrorKind::InvalidInput,
159+
format!("invalid websocket URL `{websocket_url}`: {err}"),
160+
)
161+
})?;
162+
if let Some(auth_token) = args.auth_token.as_deref() {
163+
let header_value =
164+
HeaderValue::from_str(&format!("Bearer {auth_token}")).map_err(|err| {
165+
IoError::new(
166+
ErrorKind::InvalidInput,
167+
format!("invalid remote authorization header value: {err}"),
168+
)
169+
})?;
170+
request.headers_mut().insert(AUTHORIZATION, header_value);
171+
}
172+
let stream = timeout(CONNECT_TIMEOUT, connect_async(request))
135173
.await
136174
.map_err(|_| {
137175
IoError::new(

0 commit comments

Comments
 (0)