Skip to content

Commit ac813b6

Browse files
committed
Merge branch 'main' into merogge/inline-help-chat
2 parents 2a00cb2 + 8716df3 commit ac813b6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+712
-596
lines changed

cli/build.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ fn apply_build_environment_variables() {
2525
}
2626

2727
let pkg_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
28-
let mut cmd = Command::new("node");
28+
let mut cmd = Command::new(env::var("NODE_PATH").unwrap_or_else(|_| "node".to_string()));
2929
cmd.arg("../build/azure-pipelines/cli/prepare.js");
3030
cmd.current_dir(&pkg_dir);
3131
cmd.env("VSCODE_CLI_PREPARE_OUTPUT", "json");

cli/src/async_pipe.rs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
*--------------------------------------------------------------------------------------------*/
55

66
use crate::{constants::APPLICATION_NAME, util::errors::CodeError};
7+
use async_trait::async_trait;
78
use std::path::{Path, PathBuf};
9+
use tokio::io::{AsyncRead, AsyncWrite};
10+
use tokio::net::TcpListener;
811
use uuid::Uuid;
912

1013
// todo: we could probably abstract this into some crate, if one doesn't already exist
@@ -39,7 +42,7 @@ cfg_if::cfg_if! {
3942
pipe.into_split()
4043
}
4144
} else {
42-
use tokio::{time::sleep, io::{AsyncRead, AsyncWrite, ReadBuf}};
45+
use tokio::{time::sleep, io::ReadBuf};
4346
use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions, NamedPipeClient, NamedPipeServer};
4447
use std::{time::Duration, pin::Pin, task::{Context, Poll}, io};
4548
use pin_project::pin_project;
@@ -181,3 +184,34 @@ pub fn get_socket_name() -> PathBuf {
181184
}
182185
}
183186
}
187+
188+
pub type AcceptedRW = (
189+
Box<dyn AsyncRead + Send + Unpin>,
190+
Box<dyn AsyncWrite + Send + Unpin>,
191+
);
192+
193+
#[async_trait]
194+
pub trait AsyncRWAccepter {
195+
async fn accept_rw(&mut self) -> Result<AcceptedRW, CodeError>;
196+
}
197+
198+
#[async_trait]
199+
impl AsyncRWAccepter for AsyncPipeListener {
200+
async fn accept_rw(&mut self) -> Result<AcceptedRW, CodeError> {
201+
let pipe = self.accept().await?;
202+
let (read, write) = socket_stream_split(pipe);
203+
Ok((Box::new(read), Box::new(write)))
204+
}
205+
}
206+
207+
#[async_trait]
208+
impl AsyncRWAccepter for TcpListener {
209+
async fn accept_rw(&mut self) -> Result<AcceptedRW, CodeError> {
210+
let (stream, _) = self
211+
.accept()
212+
.await
213+
.map_err(CodeError::AsyncPipeListenerFailed)?;
214+
let (read, write) = tokio::io::split(stream);
215+
Ok((Box::new(read), Box::new(write)))
216+
}
217+
}

cli/src/commands/args.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,12 @@ pub struct CommandShellArgs {
182182
/// Listen on a socket instead of stdin/stdout.
183183
#[clap(long)]
184184
pub on_socket: bool,
185+
/// Listen on a port instead of stdin/stdout.
186+
#[clap(long)]
187+
pub on_port: bool,
188+
/// Require the given token string to be given in the handshake.
189+
#[clap(long)]
190+
pub require_token: Option<String>,
185191
}
186192

187193
#[derive(Args, Debug, Clone)]

cli/src/commands/tunnels.rs

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use super::{
2020
};
2121

2222
use crate::{
23-
async_pipe::{get_socket_name, listen_socket_rw_stream, socket_stream_split},
23+
async_pipe::{get_socket_name, listen_socket_rw_stream, AsyncRWAccepter},
2424
auth::Auth,
2525
constants::{APPLICATION_NAME, TUNNEL_CLI_LOCK_NAME, TUNNEL_SERVICE_LOCK_NAME},
2626
log,
@@ -35,7 +35,7 @@ use crate::{
3535
singleton_server::{
3636
make_singleton_server, start_singleton_server, BroadcastLogSink, SingletonServerArgs,
3737
},
38-
Next, ServeStreamParams, ServiceContainer, ServiceManager,
38+
AuthRequired, Next, ServeStreamParams, ServiceContainer, ServiceManager,
3939
},
4040
util::{
4141
app_lock::AppMutex,
@@ -128,36 +128,52 @@ pub async fn command_shell(ctx: CommandContext, args: CommandShellArgs) -> Resul
128128
log: ctx.log,
129129
launcher_paths: ctx.paths,
130130
platform,
131-
requires_auth: true,
131+
requires_auth: args
132+
.require_token
133+
.map(AuthRequired::VSDAWithToken)
134+
.unwrap_or(AuthRequired::VSDA),
132135
exit_barrier: ShutdownRequest::create_rx([ShutdownRequest::CtrlC]),
133136
code_server_args: (&ctx.args).into(),
134137
};
135138

136-
if !args.on_socket {
137-
serve_stream(tokio::io::stdin(), tokio::io::stderr(), params).await;
138-
return Ok(0);
139-
}
139+
let mut listener: Box<dyn AsyncRWAccepter> = match (args.on_port, args.on_socket) {
140+
(_, true) => {
141+
let socket = get_socket_name();
142+
let listener = listen_socket_rw_stream(&socket)
143+
.await
144+
.map_err(|e| wrap(e, "error listening on socket"))?;
140145

141-
let socket = get_socket_name();
142-
let mut listener = listen_socket_rw_stream(&socket)
143-
.await
144-
.map_err(|e| wrap(e, "error listening on socket"))?;
146+
params
147+
.log
148+
.result(format!("Listening on {}", socket.display()));
149+
150+
Box::new(listener)
151+
}
152+
(true, _) => {
153+
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
154+
.await
155+
.map_err(|e| wrap(e, "error listening on port"))?;
145156

146-
params
147-
.log
148-
.result(format!("Listening on {}", socket.display()));
157+
params
158+
.log
159+
.result(format!("Listening on {}", listener.local_addr().unwrap()));
160+
161+
Box::new(listener)
162+
}
163+
_ => {
164+
serve_stream(tokio::io::stdin(), tokio::io::stderr(), params).await;
165+
return Ok(0);
166+
}
167+
};
149168

150169
let mut servers = FuturesUnordered::new();
151170

152171
loop {
153172
tokio::select! {
154173
Some(_) = servers.next() => {},
155-
socket = listener.accept() => {
174+
socket = listener.accept_rw() => {
156175
match socket {
157-
Ok(s) => {
158-
let (read, write) = socket_stream_split(s);
159-
servers.push(serve_stream(read, write, params.clone()));
160-
},
176+
Ok((read, write)) => servers.push(serve_stream(read, write, params.clone())),
161177
Err(e) => {
162178
error!(params.log, &format!("Error accepting connection: {}", e));
163179
return Ok(1);

cli/src/msgpack_rpc.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ pub struct MsgPackCodec<T> {
122122
impl<T> MsgPackCodec<T> {
123123
pub fn new() -> Self {
124124
Self {
125-
_marker: std::marker::PhantomData::default(),
125+
_marker: std::marker::PhantomData,
126126
}
127127
}
128128
}

cli/src/tunnels.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ mod service_macos;
3434
mod service_windows;
3535
mod socket_signal;
3636

37-
pub use control_server::{serve, serve_stream, Next, ServeStreamParams};
37+
pub use control_server::{serve, serve_stream, Next, ServeStreamParams, AuthRequired};
3838
pub use nosleep::SleepInhibitor;
3939
pub use service::{
4040
create_service_manager, ServiceContainer, ServiceManager, SERVICE_LOG_FILE_NAME,

cli/src/tunnels/control_server.rs

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,11 @@ use super::dev_tunnels::ActiveTunnel;
4848
use super::paths::prune_stopped_servers;
4949
use super::port_forwarder::{PortForwarding, PortForwardingProcessor};
5050
use super::protocol::{
51-
AcquireCliParams, CallServerHttpParams, CallServerHttpResult, ChallengeIssueResponse,
52-
ChallengeVerifyParams, ClientRequestMethod, EmptyObject, ForwardParams, ForwardResult,
53-
FsStatRequest, FsStatResponse, GetEnvResponse, GetHostnameResponse, HttpBodyParams,
54-
HttpHeadersParams, ServeParams, ServerLog, ServerMessageParams, SpawnParams, SpawnResult,
55-
ToClientRequest, UnforwardParams, UpdateParams, UpdateResult, VersionResponse,
51+
AcquireCliParams, CallServerHttpParams, CallServerHttpResult, ChallengeIssueParams,
52+
ChallengeIssueResponse, ChallengeVerifyParams, ClientRequestMethod, EmptyObject, ForwardParams,
53+
ForwardResult, FsStatRequest, FsStatResponse, GetEnvResponse, GetHostnameResponse,
54+
HttpBodyParams, HttpHeadersParams, ServeParams, ServerLog, ServerMessageParams, SpawnParams,
55+
SpawnResult, ToClientRequest, UnforwardParams, UpdateParams, UpdateResult, VersionResponse,
5656
METHOD_CHALLENGE_VERIFY,
5757
};
5858
use super::server_bridge::ServerBridge;
@@ -94,8 +94,8 @@ struct HandlerContext {
9494

9595
/// Handler auth state.
9696
enum AuthState {
97-
/// Auth is required, we're waiting for the client to send its challenge.
98-
WaitingForChallenge,
97+
/// Auth is required, we're waiting for the client to send its challenge optionally bearing a token.
98+
WaitingForChallenge(Option<String>),
9999
/// A challenge has been issued. Waiting for a verification.
100100
ChallengeIssued(String),
101101
/// Auth is no longer required.
@@ -215,7 +215,7 @@ pub async fn serve(
215215
code_server_args: own_code_server_args,
216216
platform,
217217
exit_barrier: own_exit,
218-
requires_auth: false,
218+
requires_auth: AuthRequired::None,
219219
}).with_context(cx.clone()).await;
220220

221221
cx.span().add_event(
@@ -233,13 +233,20 @@ pub async fn serve(
233233
}
234234
}
235235

236+
#[derive(Clone)]
237+
pub enum AuthRequired {
238+
None,
239+
VSDA,
240+
VSDAWithToken(String),
241+
}
242+
236243
#[derive(Clone)]
237244
pub struct ServeStreamParams {
238245
pub log: log::Logger,
239246
pub launcher_paths: LauncherPaths,
240247
pub code_server_args: CodeServerArgs,
241248
pub platform: Platform,
242-
pub requires_auth: bool,
249+
pub requires_auth: AuthRequired,
243250
pub exit_barrier: Barrier<ShutdownSignal>,
244251
}
245252

@@ -269,16 +276,17 @@ fn make_socket_rpc(
269276
launcher_paths: LauncherPaths,
270277
code_server_args: CodeServerArgs,
271278
port_forwarding: Option<PortForwarding>,
272-
requires_auth: bool,
279+
requires_auth: AuthRequired,
273280
platform: Platform,
274281
) -> RpcDispatcher<MsgPackSerializer, HandlerContext> {
275282
let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new()));
276283
let server_bridges = ServerMultiplexer::new();
277284
let mut rpc = RpcBuilder::new(MsgPackSerializer {}).methods(HandlerContext {
278285
did_update: Arc::new(AtomicBool::new(false)),
279286
auth_state: Arc::new(std::sync::Mutex::new(match requires_auth {
280-
true => AuthState::WaitingForChallenge,
281-
false => AuthState::Authenticated,
287+
AuthRequired::VSDAWithToken(t) => AuthState::WaitingForChallenge(Some(t)),
288+
AuthRequired::VSDA => AuthState::WaitingForChallenge(None),
289+
AuthRequired::None => AuthState::Authenticated,
282290
})),
283291
socket_tx,
284292
log: log.clone(),
@@ -305,8 +313,8 @@ fn make_socket_rpc(
305313
ensure_auth(&c.auth_state)?;
306314
handle_get_env()
307315
});
308-
rpc.register_sync(METHOD_CHALLENGE_ISSUE, |_: EmptyObject, c| {
309-
handle_challenge_issue(&c.auth_state)
316+
rpc.register_sync(METHOD_CHALLENGE_ISSUE, |p: ChallengeIssueParams, c| {
317+
handle_challenge_issue(p, &c.auth_state)
310318
});
311319
rpc.register_sync(METHOD_CHALLENGE_VERIFY, |p: ChallengeVerifyParams, c| {
312320
handle_challenge_verify(p.response, &c.auth_state)
@@ -423,6 +431,7 @@ async fn process_socket(
423431
let rx_counter = Arc::new(AtomicUsize::new(0));
424432
let http_requests = Arc::new(std::sync::Mutex::new(HashMap::new()));
425433

434+
let already_authed = matches!(requires_auth, AuthRequired::None);
426435
let rpc = make_socket_rpc(
427436
log.clone(),
428437
socket_tx.clone(),
@@ -440,7 +449,7 @@ async fn process_socket(
440449
let socket_tx = socket_tx.clone();
441450
let exit_barrier = exit_barrier.clone();
442451
tokio::spawn(async move {
443-
if !requires_auth {
452+
if already_authed {
444453
send_version(&socket_tx).await;
445454
}
446455

@@ -826,13 +835,22 @@ fn handle_get_env() -> Result<GetEnvResponse, AnyError> {
826835
}
827836

828837
fn handle_challenge_issue(
838+
params: ChallengeIssueParams,
829839
auth_state: &Arc<std::sync::Mutex<AuthState>>,
830840
) -> Result<ChallengeIssueResponse, AnyError> {
831841
let challenge = create_challenge();
832842

833843
let mut auth_state = auth_state.lock().unwrap();
834-
*auth_state = AuthState::ChallengeIssued(challenge.clone());
844+
if let AuthState::WaitingForChallenge(Some(s)) = &*auth_state {
845+
println!("looking for token {}, got {:?}", s, params.token);
846+
match &params.token {
847+
Some(t) if s != t => return Err(CodeError::AuthChallengeBadToken.into()),
848+
None => return Err(CodeError::AuthChallengeBadToken.into()),
849+
_ => {}
850+
}
851+
}
835852

853+
*auth_state = AuthState::ChallengeIssued(challenge.clone());
836854
Ok(ChallengeIssueResponse { challenge })
837855
}
838856

@@ -844,7 +862,7 @@ fn handle_challenge_verify(
844862

845863
match &*auth_state {
846864
AuthState::Authenticated => Ok(EmptyObject {}),
847-
AuthState::WaitingForChallenge => Err(CodeError::AuthChallengeNotIssued.into()),
865+
AuthState::WaitingForChallenge(_) => Err(CodeError::AuthChallengeNotIssued.into()),
848866
AuthState::ChallengeIssued(c) => match verify_challenge(c, &response) {
849867
false => Err(CodeError::AuthChallengeNotIssued.into()),
850868
true => {

cli/src/tunnels/dev_tunnels.rs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ use tunnels::management::{
3333

3434
use super::wsl_detect::is_wsl_installed;
3535

36+
static TUNNEL_COUNT_LIMIT_NAME: &str = "TunnelsPerUserPerLocation";
37+
3638
#[derive(Clone, Serialize, Deserialize)]
3739
pub struct PersistedTunnel {
3840
pub name: String,
@@ -458,8 +460,6 @@ impl DevTunnels {
458460

459461
self.check_is_name_free(name).await?;
460462

461-
let mut tried_recycle = false;
462-
463463
let new_tunnel = Tunnel {
464464
tags: vec![
465465
name.to_string(),
@@ -480,13 +480,14 @@ impl DevTunnels {
480480
Err(HttpError::ResponseError(e))
481481
if e.status_code == StatusCode::TOO_MANY_REQUESTS =>
482482
{
483-
if !tried_recycle && self.try_recycle_tunnel().await? {
484-
tried_recycle = true;
485-
continue;
486-
}
487-
488483
if let Some(d) = e.get_details() {
489484
let detail = d.detail.unwrap_or_else(|| "unknown".to_string());
485+
if detail.contains(TUNNEL_COUNT_LIMIT_NAME)
486+
&& self.try_recycle_tunnel().await?
487+
{
488+
continue;
489+
}
490+
490491
return Err(AnyError::from(TunnelCreationFailed(
491492
name.to_string(),
492493
detail,

cli/src/tunnels/protocol.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,11 @@ pub struct SpawnResult {
199199
pub const METHOD_CHALLENGE_ISSUE: &str = "challenge_issue";
200200
pub const METHOD_CHALLENGE_VERIFY: &str = "challenge_verify";
201201

202+
#[derive(Serialize, Deserialize)]
203+
pub struct ChallengeIssueParams {
204+
pub token: Option<String>,
205+
}
206+
202207
#[derive(Serialize, Deserialize)]
203208
pub struct ChallengeIssueResponse {
204209
pub challenge: String,

cli/src/util/errors.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,8 @@ pub enum CodeError {
509509
ServerAuthRequired,
510510
#[error("challenge not yet issued")]
511511
AuthChallengeNotIssued,
512+
#[error("challenge token is invalid")]
513+
AuthChallengeBadToken,
512514
#[error("unauthorized client refused")]
513515
AuthMismatch,
514516
#[error("keyring communication timed out after 5s")]

0 commit comments

Comments
 (0)