|
5 | 5 |
|
6 | 6 | use async_trait::async_trait;
|
7 | 7 | use base64::{engine::general_purpose as b64, Engine as _};
|
| 8 | +use futures::{stream::FuturesUnordered, StreamExt}; |
8 | 9 | use serde::Serialize;
|
9 | 10 | use sha2::{Digest, Sha256};
|
10 | 11 | use std::{str::FromStr, time::Duration};
|
11 | 12 | use sysinfo::Pid;
|
12 | 13 |
|
13 | 14 | use super::{
|
14 | 15 | args::{
|
15 |
| - AuthProvider, CliCore, ExistingTunnelArgs, TunnelRenameArgs, TunnelServeArgs, |
16 |
| - TunnelServiceSubCommands, TunnelUserSubCommands, |
| 16 | + AuthProvider, CliCore, CommandShellArgs, ExistingTunnelArgs, TunnelRenameArgs, |
| 17 | + TunnelServeArgs, TunnelServiceSubCommands, TunnelUserSubCommands, |
17 | 18 | },
|
18 | 19 | CommandContext,
|
19 | 20 | };
|
20 | 21 |
|
21 | 22 | use crate::{
|
| 23 | + async_pipe::{get_socket_name, listen_socket_rw_stream, socket_stream_split}, |
22 | 24 | auth::Auth,
|
23 | 25 | constants::{APPLICATION_NAME, TUNNEL_CLI_LOCK_NAME, TUNNEL_SERVICE_LOCK_NAME},
|
24 | 26 | log,
|
@@ -120,23 +122,55 @@ impl ServiceContainer for TunnelServiceContainer {
|
120 | 122 | }
|
121 | 123 | }
|
122 | 124 |
|
123 |
| -pub async fn command_shell(ctx: CommandContext) -> Result<i32, AnyError> { |
| 125 | +pub async fn command_shell(ctx: CommandContext, args: CommandShellArgs) -> Result<i32, AnyError> { |
124 | 126 | let platform = PreReqChecker::new().verify().await?;
|
125 |
| - serve_stream( |
126 |
| - tokio::io::stdin(), |
127 |
| - tokio::io::stderr(), |
128 |
| - ServeStreamParams { |
129 |
| - log: ctx.log, |
130 |
| - launcher_paths: ctx.paths, |
131 |
| - platform, |
132 |
| - requires_auth: true, |
133 |
| - exit_barrier: ShutdownRequest::create_rx([ShutdownRequest::CtrlC]), |
134 |
| - code_server_args: (&ctx.args).into(), |
135 |
| - }, |
136 |
| - ) |
137 |
| - .await; |
| 127 | + let mut params = ServeStreamParams { |
| 128 | + log: ctx.log, |
| 129 | + launcher_paths: ctx.paths, |
| 130 | + platform, |
| 131 | + requires_auth: true, |
| 132 | + exit_barrier: ShutdownRequest::create_rx([ShutdownRequest::CtrlC]), |
| 133 | + code_server_args: (&ctx.args).into(), |
| 134 | + }; |
138 | 135 |
|
139 |
| - Ok(0) |
| 136 | + if !args.on_socket { |
| 137 | + serve_stream(tokio::io::stdin(), tokio::io::stderr(), params).await; |
| 138 | + return Ok(0); |
| 139 | + } |
| 140 | + |
| 141 | + let socket = get_socket_name(); |
| 142 | + let mut listener = listen_socket_rw_stream(&socket) |
| 143 | + .await |
| 144 | + .map_err(|e| wrap(e, "error listening on socket"))?; |
| 145 | + |
| 146 | + params |
| 147 | + .log |
| 148 | + .result(format!("Listening on {}", socket.display())); |
| 149 | + |
| 150 | + let mut servers = FuturesUnordered::new(); |
| 151 | + |
| 152 | + loop { |
| 153 | + tokio::select! { |
| 154 | + Some(_) = servers.next() => {}, |
| 155 | + socket = listener.accept() => { |
| 156 | + match socket { |
| 157 | + Ok(s) => { |
| 158 | + let (read, write) = socket_stream_split(s); |
| 159 | + servers.push(serve_stream(read, write, params.clone())); |
| 160 | + }, |
| 161 | + Err(e) => { |
| 162 | + error!(params.log, &format!("Error accepting connection: {}", e)); |
| 163 | + return Ok(1); |
| 164 | + } |
| 165 | + } |
| 166 | + }, |
| 167 | + _ = params.exit_barrier.wait() => { |
| 168 | + // wait for all servers to finish up: |
| 169 | + while (servers.next().await).is_some() { } |
| 170 | + return Ok(0); |
| 171 | + } |
| 172 | + } |
| 173 | + } |
140 | 174 | }
|
141 | 175 |
|
142 | 176 | pub async fn service(
|
|
0 commit comments