Skip to content

Commit 11302b3

Browse files
committed
feat: add ssh and rsync commands
1 parent 007a5d4 commit 11302b3

File tree

10 files changed

+336
-110
lines changed

10 files changed

+336
-110
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
dist
2+
test

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
# wush - secure shells and file transfers behind nat
1+
# wush - wireguard powered shells and file transfers behind nat
2+
3+
[![Go Reference](https://pkg.go.dev/badge/github.com/coder/wush.svg)](https://pkg.go.dev/github.com/coder/wush)

cmd/wush/main.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@ import (
88
)
99

1010
func main() {
11-
cmd := sendCmd()
11+
cmd := sshCmd()
1212
cmd.Children = []*serpent.Command{
1313
receiveCmd(),
14+
rsyncCmd(),
1415
}
1516
err := cmd.Invoke().WithOS().Run()
1617
if err != nil {

cmd/wush/rsync.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
package main
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"log/slog"
7+
"net/netip"
8+
"os/exec"
9+
"strings"
10+
11+
"github.com/charmbracelet/huh"
12+
13+
"github.com/coder/serpent"
14+
"github.com/coder/wush/cliui"
15+
"github.com/coder/wush/overlay"
16+
"github.com/coder/wush/tsserver"
17+
)
18+
19+
func rsyncCmd() *serpent.Command {
20+
var (
21+
authID string
22+
overlayTransport string
23+
stunAddrOverride string
24+
stunAddrOverrideIP netip.Addr
25+
sshStdio bool
26+
)
27+
return &serpent.Command{
28+
Use: "rsync",
29+
Long: "Runs rsync to transfer files to a " + cliui.Code("wush") + " peer. " +
30+
"Use " + cliui.Code("wush receive") + " on the computer you would like to connect to.",
31+
Handler: func(inv *serpent.Invocation) error {
32+
ctx := inv.Context()
33+
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
34+
35+
if authID == "" {
36+
err := huh.NewInput().
37+
Title("Enter your Auth ID:").
38+
Value(&authID).
39+
Run()
40+
if err != nil {
41+
return fmt.Errorf("get auth id: %w", err)
42+
}
43+
}
44+
45+
dm, err := tsserver.DERPMapTailscale(ctx)
46+
if err != nil {
47+
return err
48+
}
49+
50+
if stunAddrOverride != "" {
51+
stunAddrOverrideIP, err = netip.ParseAddr(stunAddrOverride)
52+
if err != nil {
53+
return fmt.Errorf("parse stun addr override: %w", err)
54+
}
55+
}
56+
57+
send := overlay.NewSendOverlay(logger, dm)
58+
send.STUNIPOverride = stunAddrOverrideIP
59+
60+
err = send.Auth.Parse(authID)
61+
if err != nil {
62+
return fmt.Errorf("parse auth key: %w", err)
63+
}
64+
65+
fmt.Println("Auth information:")
66+
stunStr := send.Auth.ReceiverStunAddr.String()
67+
if !send.Auth.ReceiverStunAddr.IsValid() {
68+
stunStr = "Disabled"
69+
}
70+
fmt.Println("\t> Server overlay STUN address:", cliui.Code(stunStr))
71+
derpStr := "Disabled"
72+
if send.Auth.ReceiverDERPRegionID > 0 {
73+
derpStr = dm.Regions[int(send.Auth.ReceiverDERPRegionID)].RegionName
74+
}
75+
fmt.Println("\t> Server overlay DERP home: ", cliui.Code(derpStr))
76+
fmt.Println("\t> Server overlay public key: ", cliui.Code(send.Auth.ReceiverPublicKey.ShortString()))
77+
fmt.Println("\t> Server overlay auth key: ", cliui.Code(send.Auth.OverlayPrivateKey.Public().ShortString()))
78+
79+
args := []string{
80+
"-c",
81+
"rsync --progress --stats -avz --human-readable " + fmt.Sprintf("-e=\"wush --auth-id %s --stdio --\" ", send.Auth.AuthKey()) + strings.Join(inv.Args, " "),
82+
}
83+
fmt.Println("Running: rsync", args)
84+
cmd := exec.CommandContext(ctx, "sh", args...)
85+
cmd.Stdin = inv.Stdin
86+
cmd.Stdout = inv.Stdout
87+
cmd.Stderr = inv.Stderr
88+
89+
return cmd.Run()
90+
},
91+
Options: []serpent.Option{
92+
{
93+
Flag: "auth-id",
94+
Env: "WUSH_AUTH_ID",
95+
Description: "The auth id returned by " + cliui.Code("wush receive") + ". If not provided, it will be asked for on startup.",
96+
Default: "",
97+
Value: serpent.StringOf(&authID),
98+
},
99+
{
100+
Flag: "overlay-transport",
101+
Description: "The transport to use on the overlay. The overlay is used to exchange Wireguard nodes between peers. In DERP mode, nodes are exchanged over public Tailscale DERPs, while STUN mode sends nodes directly over UDP.",
102+
Default: "derp",
103+
Value: serpent.EnumOf(&overlayTransport, "derp", "stun"),
104+
},
105+
{
106+
Flag: "stun-ip-override",
107+
Default: "",
108+
Value: serpent.StringOf(&stunAddrOverride),
109+
},
110+
{
111+
Flag: "stdio",
112+
Description: "Run SSH over stdin/stdout. This allows wush to be used as a transport for other programs, like rsync or regular ssh.",
113+
Default: "false",
114+
Value: serpent.BoolOf(&sshStdio),
115+
},
116+
},
117+
}
118+
}

cmd/wush/send.go renamed to cmd/wush/ssh.go

Lines changed: 35 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,33 @@ import (
77
"io"
88
"log/slog"
99
"net/netip"
10-
"os"
1110
"time"
1211

1312
"github.com/charmbracelet/huh"
14-
"github.com/mattn/go-isatty"
15-
"golang.org/x/crypto/ssh"
16-
"golang.org/x/term"
17-
"golang.org/x/xerrors"
1813
"tailscale.com/client/tailscale"
1914
"tailscale.com/net/netns"
2015
"tailscale.com/tailcfg"
2116

22-
"github.com/coder/coder/v2/pty"
2317
"github.com/coder/serpent"
2418
"github.com/coder/wush/cliui"
2519
"github.com/coder/wush/overlay"
2620
"github.com/coder/wush/tsserver"
2721
xssh "github.com/coder/wush/xssh"
2822
)
2923

30-
func sendCmd() *serpent.Command {
24+
func sshCmd() *serpent.Command {
3125
var (
3226
authID string
3327
waitP2P bool
3428
overlayTransport string
3529
stunAddrOverride string
3630
stunAddrOverrideIP netip.Addr
31+
sshStdio bool
3732
)
3833
return &serpent.Command{
39-
Use: "send",
34+
Use: "wush",
35+
Long: "Opens an SSH connection to a " + cliui.Code("wush") + " peer. " +
36+
"Use " + cliui.Code("wush receive") + " on the computer you would like to connect to.",
4037
Handler: func(inv *serpent.Invocation) error {
4138
ctx := inv.Context()
4239
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
@@ -71,19 +68,21 @@ func sendCmd() *serpent.Command {
7168
return fmt.Errorf("parse auth key: %w", err)
7269
}
7370

74-
fmt.Println("Auth information:")
75-
stunStr := send.Auth.ReceiverStunAddr.String()
76-
if !send.Auth.ReceiverStunAddr.IsValid() {
77-
stunStr = "Disabled"
78-
}
79-
fmt.Println("\t> Server overlay STUN address:", cliui.Code(stunStr))
80-
derpStr := "Disabled"
81-
if send.Auth.ReceiverDERPRegionID > 0 {
82-
derpStr = dm.Regions[int(send.Auth.ReceiverDERPRegionID)].RegionName
71+
if !sshStdio {
72+
fmt.Println("Auth information:")
73+
stunStr := send.Auth.ReceiverStunAddr.String()
74+
if !send.Auth.ReceiverStunAddr.IsValid() {
75+
stunStr = "Disabled"
76+
}
77+
fmt.Println("\t> Server overlay STUN address:", cliui.Code(stunStr))
78+
derpStr := "Disabled"
79+
if send.Auth.ReceiverDERPRegionID > 0 {
80+
derpStr = dm.Regions[int(send.Auth.ReceiverDERPRegionID)].RegionName
81+
}
82+
fmt.Println("\t> Server overlay DERP home: ", cliui.Code(derpStr))
83+
fmt.Println("\t> Server overlay public key: ", cliui.Code(send.Auth.ReceiverPublicKey.ShortString()))
84+
fmt.Println("\t> Server overlay auth key: ", cliui.Code(send.Auth.OverlayPrivateKey.Public().ShortString()))
8385
}
84-
fmt.Println("\t> Server overlay DERP home: ", cliui.Code(derpStr))
85-
fmt.Println("\t> Server overlay public key: ", cliui.Code(send.Auth.ReceiverPublicKey.ShortString()))
86-
fmt.Println("\t> Server overlay auth key: ", cliui.Code(send.Auth.OverlayPrivateKey.Public().ShortString()))
8786

8887
s, err := tsserver.NewServer(ctx, logger, send)
8988
if err != nil {
@@ -112,9 +111,9 @@ func sendCmd() *serpent.Command {
112111
ts.Logf = func(string, ...any) {}
113112
ts.UserLogf = func(string, ...any) {}
114113

115-
fmt.Println("Bringing Wireguard up..")
114+
// fmt.Println("Bringing Wireguard up..")
116115
ts.Up(ctx)
117-
fmt.Println("Wireguard is ready!")
116+
// fmt.Println("Wireguard is ready!")
118117

119118
lc, err := ts.LocalClient()
120119
if err != nil {
@@ -133,87 +132,13 @@ func sendCmd() *serpent.Command {
133132
}
134133
}
135134

136-
conn, err := ts.Dial(ctx, "tcp", ip.String()+":3")
137-
if err != nil {
138-
return err
139-
}
140-
141-
sshConn, channels, requests, err := ssh.NewClientConn(conn, "localhost:22", &ssh.ClientConfig{
142-
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
143-
})
144-
if err != nil {
145-
return err
146-
}
147-
148-
sshClient := ssh.NewClient(sshConn, channels, requests)
149-
sshSession, err := sshClient.NewSession()
150-
if err != nil {
151-
return err
152-
}
153-
154-
stdinFile, validIn := inv.Stdin.(*os.File)
155-
stdoutFile, validOut := inv.Stdout.(*os.File)
156-
if validIn && validOut && isatty.IsTerminal(stdinFile.Fd()) && isatty.IsTerminal(stdoutFile.Fd()) {
157-
inState, err := pty.MakeInputRaw(stdinFile.Fd())
158-
if err != nil {
159-
return err
160-
}
161-
defer func() {
162-
_ = pty.RestoreTerminal(stdinFile.Fd(), inState)
163-
}()
164-
outState, err := pty.MakeOutputRaw(stdoutFile.Fd())
165-
if err != nil {
166-
return err
167-
}
168-
defer func() {
169-
_ = pty.RestoreTerminal(stdoutFile.Fd(), outState)
170-
}()
171-
172-
windowChange := xssh.ListenWindowSize(ctx)
173-
go func() {
174-
for {
175-
select {
176-
case <-ctx.Done():
177-
return
178-
case <-windowChange:
179-
}
180-
width, height, err := term.GetSize(int(stdoutFile.Fd()))
181-
if err != nil {
182-
continue
183-
}
184-
_ = sshSession.WindowChange(height, width)
185-
}
186-
}()
187-
}
188-
189-
err = sshSession.RequestPty("xterm-256color", 128, 128, ssh.TerminalModes{})
190-
if err != nil {
191-
return xerrors.Errorf("request pty: %w", err)
192-
}
193-
194-
sshSession.Stdin = inv.Stdin
195-
sshSession.Stdout = inv.Stdout
196-
sshSession.Stderr = inv.Stderr
197-
198-
err = sshSession.Shell()
199-
if err != nil {
200-
return xerrors.Errorf("start shell: %w", err)
201-
}
202-
203-
if validOut {
204-
// Set initial window size.
205-
width, height, err := term.GetSize(int(stdoutFile.Fd()))
206-
if err == nil {
207-
_ = sshSession.WindowChange(height, width)
208-
}
209-
}
210-
211-
return sshSession.Wait()
135+
return xssh.TailnetSSH(ctx, inv, ts, ip.String()+":3", sshStdio)
212136
},
213137
Options: []serpent.Option{
214138
{
215139
Flag: "auth-id",
216-
Description: "The auth id returned by `wush receive`. If not provided, it will be asked for on startup.",
140+
Env: "WUSH_AUTH_ID",
141+
Description: "The auth id returned by " + cliui.Code("wush receive") + ". If not provided, it will be asked for on startup.",
217142
Default: "",
218143
Value: serpent.StringOf(&authID),
219144
},
@@ -228,6 +153,12 @@ func sendCmd() *serpent.Command {
228153
Default: "",
229154
Value: serpent.StringOf(&stunAddrOverride),
230155
},
156+
{
157+
Flag: "stdio",
158+
Description: "Run SSH over stdin/stdout. This allows wush to be used as a transport for other programs, like rsync or regular ssh.",
159+
Default: "false",
160+
Value: serpent.BoolOf(&sshStdio),
161+
},
231162
},
232163
}
233164
}
@@ -248,11 +179,11 @@ func waitUntilHasPeerHasIP(ctx context.Context, lc *tailscale.LocalClient) (neti
248179

249180
peers := stat.Peers()
250181
if len(peers) == 0 {
251-
fmt.Println("No peer yet")
182+
// fmt.Println("No peer yet")
252183
continue
253184
}
254185

255-
fmt.Println("Received peer")
186+
// fmt.Println("Received peer")
256187

257188
peer, ok := stat.Peer[peers[0]]
258189
if !ok {
@@ -265,7 +196,7 @@ func waitUntilHasPeerHasIP(ctx context.Context, lc *tailscale.LocalClient) (neti
265196
continue
266197
}
267198

268-
fmt.Println("Peer active with relay", cliui.Code(peer.Relay))
199+
// fmt.Println("Peer active with relay", cliui.Code(peer.Relay))
269200

270201
if len(peer.TailscaleIPs) == 0 {
271202
fmt.Println("peer has no ips (developer error)")
@@ -302,7 +233,7 @@ func waitUntilHasP2P(ctx context.Context, lc *tailscale.LocalClient) error {
302233
continue
303234
}
304235

305-
fmt.Println("Peer active with relay", cliui.Code(peer.Relay))
236+
// fmt.Println("Peer active with relay", cliui.Code(peer.Relay))
306237

307238
if len(peer.TailscaleIPs) == 0 {
308239
fmt.Println("peer has no ips (developer error)")
@@ -322,7 +253,7 @@ func waitUntilHasP2P(ctx context.Context, lc *tailscale.LocalClient) error {
322253
continue
323254
}
324255

325-
fmt.Println("Peer active over p2p", cliui.Code(pong.Endpoint))
256+
// fmt.Println("Peer active over p2p", cliui.Code(pong.Endpoint))
326257
return nil
327258
}
328259
}

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ require (
3030
golang.org/x/sys v0.24.0
3131
golang.org/x/term v0.22.0
3232
golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9
33+
gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987
3334
tailscale.com v1.70.0
3435
)
3536

@@ -215,7 +216,6 @@ require (
215216
google.golang.org/protobuf v1.34.2 // indirect
216217
gopkg.in/DataDog/dd-trace-go.v1 v1.64.0 // indirect
217218
gopkg.in/yaml.v3 v3.0.1 // indirect
218-
gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987 // indirect
219219
nhooyr.io/websocket v1.8.10 // indirect
220220
storj.io/drpc v0.0.33 // indirect
221221
)

overlay/send.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ func (s *Send) handleNextMessage(msg []byte) (resRaw []byte, _ error) {
259259
s.waitIPOnce.Do(func() {
260260
close(s.waitIP)
261261
})
262-
fmt.Println(cliui.Timestamp(time.Now()), "Received IP from peer:", s._ip.String())
262+
// fmt.Println("Received IP from peer:", s._ip.String())
263263
case messageTypeNodeUpdate:
264264
s.in <- &ovMsg.Node
265265
}

0 commit comments

Comments
 (0)