Skip to content

Commit 684abce

Browse files
committed
Fix wush ssh
1 parent 0b19548 commit 684abce

File tree

2 files changed

+5
-73
lines changed

2 files changed

+5
-73
lines changed

cmd/wush/ssh.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func sshCmd() *serpent.Command {
4040
if sshStdio {
4141
return
4242
}
43-
fmt.Fprintf(inv.Stderr, str, args...)
43+
fmt.Fprintf(inv.Stderr, str+"\n", args...)
4444
}
4545
if authID == "" {
4646
err := huh.NewInput().
@@ -82,9 +82,9 @@ func sshCmd() *serpent.Command {
8282
if send.Auth.ReceiverDERPRegionID > 0 {
8383
derpStr = dm.Regions[int(send.Auth.ReceiverDERPRegionID)].RegionName
8484
}
85-
logF("\t> Server overlay DERP home: %s", cliui.Code(derpStr))
86-
logF("\t> Server overlay public key: %s", cliui.Code(send.Auth.ReceiverPublicKey.ShortString()))
87-
logF("\t> Server overlay auth key: %s", cliui.Code(send.Auth.OverlayPrivateKey.Public().ShortString()))
85+
logF("\t> Server overlay DERP home: %s", cliui.Code(derpStr))
86+
logF("\t> Server overlay public key: %s", cliui.Code(send.Auth.ReceiverPublicKey.ShortString()))
87+
logF("\t> Server overlay auth key: %s", cliui.Code(send.Auth.OverlayPrivateKey.Public().ShortString()))
8888

8989
s, err := tsserver.NewServer(ctx, logger, send)
9090
if err != nil {

xssh/client.go

Lines changed: 1 addition & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,15 @@ package xssh
22

33
import (
44
"context"
5-
"io"
6-
"log/slog"
75
"os"
86
"strings"
9-
"sync"
10-
"time"
117

128
"github.com/coder/coder/v2/pty"
139
"github.com/coder/serpent"
1410
"github.com/mattn/go-isatty"
1511
"golang.org/x/crypto/ssh"
1612
"golang.org/x/term"
1713
"golang.org/x/xerrors"
18-
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
1914
"tailscale.com/tsnet"
2015
)
2116

@@ -49,7 +44,7 @@ func TailnetSSH(ctx context.Context, inv *serpent.Invocation, ts *tsnet.Server,
4944
sshSession.Stdout = inv.Stdout
5045
sshSession.Stderr = inv.Stderr
5146

52-
if len(inv.Args) > 0 {
47+
if len(inv.Args) > 1 {
5348
return sshSession.Run(strings.Join(inv.Args, " "))
5449
}
5550

@@ -108,66 +103,3 @@ func TailnetSSH(ctx context.Context, inv *serpent.Invocation, ts *tsnet.Server,
108103

109104
return sshSession.Wait()
110105
}
111-
112-
type rawSSHCopier struct {
113-
conn *gonet.TCPConn
114-
logger *slog.Logger
115-
r io.Reader
116-
w io.Writer
117-
118-
done chan struct{}
119-
}
120-
121-
func newRawSSHCopier(logger *slog.Logger, conn *gonet.TCPConn, r io.Reader, w io.Writer) *rawSSHCopier {
122-
return &rawSSHCopier{conn: conn, logger: logger, r: r, w: w, done: make(chan struct{})}
123-
}
124-
125-
func (c *rawSSHCopier) copy(wg *sync.WaitGroup) {
126-
defer close(c.done)
127-
logCtx := context.Background()
128-
wg.Add(1)
129-
go func() {
130-
defer wg.Done()
131-
// We close connections using CloseWrite instead of Close, so that the SSH server sees the
132-
// closed connection while reading, and shuts down cleanly. This will trigger the io.Copy
133-
// in the server-to-client direction to also be closed and the copy() routine will exit.
134-
// This ensures that we don't leave any state in the server, like forwarded ports if
135-
// copy() were to return and the underlying tailnet connection torn down before the TCP
136-
// session exits. This is a bit of a hack to block shut down at the application layer, since
137-
// we can't serialize the TCP and tailnet layers shutting down.
138-
//
139-
// Of course, if the underlying transport is broken, io.Copy will still return.
140-
defer func() {
141-
cwErr := c.conn.CloseWrite()
142-
c.logger.DebugContext(logCtx, "closed raw SSH connection for writing", "err", cwErr)
143-
}()
144-
145-
_, err := io.Copy(c.conn, c.r)
146-
if err != nil {
147-
c.logger.ErrorContext(logCtx, "copy stdin error", "err", err)
148-
} else {
149-
c.logger.DebugContext(logCtx, "copy stdin complete")
150-
}
151-
}()
152-
_, err := io.Copy(c.w, c.conn)
153-
if err != nil {
154-
c.logger.ErrorContext(logCtx, "copy stdout error", "err", err)
155-
} else {
156-
c.logger.DebugContext(logCtx, "copy stdout complete")
157-
}
158-
}
159-
160-
func (c *rawSSHCopier) Close() error {
161-
err := c.conn.CloseWrite()
162-
163-
// give the copy() call a chance to return on a timeout, so that we don't
164-
// continue tearing down and close the underlying netstack before the SSH
165-
// session has a chance to gracefully shut down.
166-
t := time.NewTimer(5 * time.Second)
167-
defer t.Stop()
168-
select {
169-
case <-c.done:
170-
case <-t.C:
171-
}
172-
return err
173-
}

0 commit comments

Comments
 (0)