|
| 1 | +import { CancellationToken, PromiseCompletionSource, SessionRequestFailureMessage, SessionRequestMessage, SessionRequestSuccessMessage, SshChannel, SshChannelClosedEventArgs, SshChannelError, SshChannelOpeningEventArgs, SshMessage, SshRequestEventArgs, SshSession, SshSessionClosedEventArgs, SshTraceEventIds, TraceLevel } from "@microsoft/dev-tunnels-ssh"; |
| 2 | +import { ChannelFailureMessage, ChannelMessage, ChannelOpenConfirmationMessage, ChannelOpenFailureMessage, ChannelRequestMessage, ChannelSuccessMessage, SshChannelOpenFailureReason } from "@microsoft/dev-tunnels-ssh/messages/connectionMessages"; |
| 3 | + |
| 4 | +/** |
| 5 | + * Extension methods for piping sessions and channels. |
| 6 | + * |
| 7 | + * Note this class is not exported from the package. Instead, the piping APIs are exposed via |
| 8 | + * public methods on the `SshSession` and `SshChannel` classes. See those respective methods |
| 9 | + * for API documentation. |
| 10 | + */ |
| 11 | +export class PipeExtensions { |
| 12 | + public static async pipeSession(session: SshSession, toSession: SshSession): Promise<void> { |
| 13 | + if (!session) throw new TypeError('Session is required.'); |
| 14 | + if (!toSession) throw new TypeError('Target session is required'); |
| 15 | + |
| 16 | + const endCompletion = new PromiseCompletionSource<Promise<void>>(); |
| 17 | + |
| 18 | + session.onRequest((e) => { |
| 19 | + e.responsePromise = PipeExtensions.forwardSessionRequest(e, toSession, e.cancellation); |
| 20 | + }); |
| 21 | + toSession.onRequest((e) => { |
| 22 | + e.responsePromise = PipeExtensions.forwardSessionRequest(e, session, e.cancellation); |
| 23 | + }); |
| 24 | + |
| 25 | + session.onChannelOpening((e) => { |
| 26 | + if (e.isRemoteRequest) { |
| 27 | + e.openingPromise = PipeExtensions.forwardChannel(e, toSession, e.cancellation); |
| 28 | + } |
| 29 | + }); |
| 30 | + toSession.onChannelOpening((e) => { |
| 31 | + if (e.isRemoteRequest) { |
| 32 | + e.openingPromise = PipeExtensions.forwardChannel(e, session, e.cancellation); |
| 33 | + } |
| 34 | + }); |
| 35 | + |
| 36 | + session.onClosed((e) => { |
| 37 | + endCompletion.resolve(PipeExtensions.forwardSessionClose(toSession, e)); |
| 38 | + }); |
| 39 | + toSession.onClosed((e) => { |
| 40 | + endCompletion.resolve(PipeExtensions.forwardSessionClose(session, e)); |
| 41 | + }); |
| 42 | + |
| 43 | + const endPromise = await endCompletion.promise; |
| 44 | + await endPromise; |
| 45 | + } |
| 46 | + |
| 47 | + public static async pipeChannel(channel: SshChannel, toChannel: SshChannel): Promise<void> { |
| 48 | + if (!channel) throw new TypeError('Channel is required.'); |
| 49 | + if (!toChannel) throw new TypeError('Target channel is required'); |
| 50 | + |
| 51 | + const endCompletion = new PromiseCompletionSource<Promise<void>>(); |
| 52 | + let closed = false; |
| 53 | + |
| 54 | + channel.onRequest((e) => { |
| 55 | + e.responsePromise = PipeExtensions.forwardChannelRequest(e, toChannel, e.cancellation); |
| 56 | + }); |
| 57 | + toChannel.onRequest((e) => { |
| 58 | + e.responsePromise = PipeExtensions.forwardChannelRequest(e, channel, e.cancellation); |
| 59 | + }); |
| 60 | + |
| 61 | + channel.onDataReceived((data) => { |
| 62 | + void PipeExtensions.forwardData(channel, toChannel, data).catch(); |
| 63 | + }); |
| 64 | + toChannel.onDataReceived((data) => { |
| 65 | + void PipeExtensions.forwardData(toChannel, channel, data).catch(); |
| 66 | + }); |
| 67 | + |
| 68 | + channel.onEof(() => { |
| 69 | + void PipeExtensions.forwardData(channel, toChannel, Buffer.alloc(0)).catch(); |
| 70 | + }); |
| 71 | + toChannel.onEof(() => { |
| 72 | + void PipeExtensions.forwardData(toChannel, channel, Buffer.alloc(0)).catch(); |
| 73 | + }); |
| 74 | + |
| 75 | + channel.onClosed((e) => { |
| 76 | + if (!closed) { |
| 77 | + closed = true; |
| 78 | + endCompletion.resolve(PipeExtensions.forwardChannelClose(channel, toChannel, e)); |
| 79 | + } |
| 80 | + }); |
| 81 | + toChannel.onClosed((e) => { |
| 82 | + if (!closed) { |
| 83 | + closed = true; |
| 84 | + endCompletion.resolve(PipeExtensions.forwardChannelClose(toChannel, channel, e)); |
| 85 | + } |
| 86 | + }); |
| 87 | + |
| 88 | + const endTask = await endCompletion.promise; |
| 89 | + await endTask; |
| 90 | + } |
| 91 | + |
| 92 | + private static async forwardSessionRequest( |
| 93 | + e: SshRequestEventArgs<SessionRequestMessage>, |
| 94 | + toSession: SshSession, |
| 95 | + cancellation?: CancellationToken, |
| 96 | + ): Promise<SshMessage> { |
| 97 | + // `SshSession.requestResponse()` always set `wantReply` to `true` internally and waits for a |
| 98 | + // response, but since the message buffer is cached the updated `wantReply` value is not sent. |
| 99 | + // Anyway, it's better to forward a no-reply message as another no-reply message, using |
| 100 | + // `SshSession.request()` instead. |
| 101 | + if (!e.request.wantReply) { |
| 102 | + return toSession |
| 103 | + .request(e.request, cancellation) |
| 104 | + .then(() => new SessionRequestSuccessMessage()); |
| 105 | + } |
| 106 | + return toSession.requestResponse( |
| 107 | + e.request, |
| 108 | + SessionRequestSuccessMessage, |
| 109 | + SessionRequestFailureMessage, |
| 110 | + cancellation, |
| 111 | + ); |
| 112 | + } |
| 113 | + |
| 114 | + private static async forwardChannel( |
| 115 | + e: SshChannelOpeningEventArgs, |
| 116 | + toSession: SshSession, |
| 117 | + cancellation?: CancellationToken, |
| 118 | + ): Promise<ChannelMessage> { |
| 119 | + try { |
| 120 | + const toChannel = await toSession.openChannel(e.request, null, cancellation); |
| 121 | + void PipeExtensions.pipeChannel(e.channel, toChannel).catch(); |
| 122 | + return new ChannelOpenConfirmationMessage(); |
| 123 | + } catch (err) { |
| 124 | + if (!(err instanceof Error)) throw err; |
| 125 | + |
| 126 | + const failureMessage = new ChannelOpenFailureMessage(); |
| 127 | + if (err instanceof SshChannelError) { |
| 128 | + failureMessage.reasonCode = err.reason ?? SshChannelOpenFailureReason.connectFailed; |
| 129 | + } else { |
| 130 | + failureMessage.reasonCode = SshChannelOpenFailureReason.connectFailed; |
| 131 | + } |
| 132 | + |
| 133 | + failureMessage.description = err.message; |
| 134 | + return failureMessage; |
| 135 | + } |
| 136 | + } |
| 137 | + |
| 138 | + private static async forwardChannelRequest( |
| 139 | + e: SshRequestEventArgs<ChannelRequestMessage>, |
| 140 | + toChannel: SshChannel, |
| 141 | + cancellation?: CancellationToken, |
| 142 | + ): Promise<SshMessage> { |
| 143 | + e.request.recipientChannel = toChannel.remoteChannelId; |
| 144 | + const result = await toChannel.request(e.request, cancellation); |
| 145 | + return result ? new ChannelSuccessMessage() : new ChannelFailureMessage(); |
| 146 | + } |
| 147 | + |
| 148 | + private static async forwardSessionClose( |
| 149 | + session: SshSession, |
| 150 | + e: SshSessionClosedEventArgs, |
| 151 | + ): Promise<void> { |
| 152 | + return session.close(e.reason, e.message, e.error ?? undefined); |
| 153 | + } |
| 154 | + |
| 155 | + private static async forwardData( |
| 156 | + channel: SshChannel, |
| 157 | + toChannel: SshChannel, |
| 158 | + data: Buffer, |
| 159 | + ): Promise<void> { |
| 160 | + // Make a copy of the buffer before sending because SshChannel.send() is an async operation |
| 161 | + // (it may need to wait for the window to open), while the buffer will be re-used for the |
| 162 | + // next message as sson as this task yields. |
| 163 | + const buffer = Buffer.alloc(data.length); |
| 164 | + data.copy(buffer); |
| 165 | + const promise = toChannel.send(buffer, CancellationToken.None); |
| 166 | + channel.adjustWindow(buffer.length); |
| 167 | + return promise; |
| 168 | + } |
| 169 | + |
| 170 | + private static async forwardChannelClose( |
| 171 | + fromChannel: SshChannel, |
| 172 | + toChannel: SshChannel, |
| 173 | + e: SshChannelClosedEventArgs, |
| 174 | + ): Promise<void> { |
| 175 | + const message = |
| 176 | + `Piping channel closure.\n` + |
| 177 | + `Source: ${fromChannel.session} ${fromChannel}\n` + |
| 178 | + `Destination: ${toChannel.session} ${toChannel}\n`; |
| 179 | + toChannel.session.trace(TraceLevel.Verbose, SshTraceEventIds.channelClosed, message); |
| 180 | + |
| 181 | + if (e.error) { |
| 182 | + toChannel.close(e.error as any); |
| 183 | + return Promise.resolve(); |
| 184 | + } else if (e.exitSignal) { |
| 185 | + return toChannel.close(e.exitSignal, e.errorMessage); |
| 186 | + } else if (typeof e.exitStatus === 'number') { |
| 187 | + return toChannel.close(e.exitStatus); |
| 188 | + } else { |
| 189 | + return toChannel.close(); |
| 190 | + } |
| 191 | + } |
| 192 | +} |
0 commit comments