Skip to content

Commit be3ea53

Browse files
committed
Automatically infer overlay type via auth key
1 parent 5f82c19 commit be3ea53

File tree

2 files changed

+52
-58
lines changed

2 files changed

+52
-58
lines changed

cmd/wush/receive.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,11 @@ import (
2222
"github.com/coder/wush/tsserver"
2323
)
2424

25-
func logF(format string, args ...any) {
26-
fmt.Printf(format+"\n", args...)
27-
}
28-
2925
func receiveCmd() *serpent.Command {
3026
var overlayType string
3127
return &serpent.Command{
32-
Use: "receive",
28+
Use: "receive",
29+
Long: "Runs the wush server. Allows other wush CLIs to connect to this computer.",
3330
Handler: func(inv *serpent.Invocation) error {
3431
ctx := inv.Context()
3532
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
@@ -112,9 +109,10 @@ func newTSNet(direction string) (*tsnet.Server, error) {
112109
srv.Ephemeral = true
113110
srv.AuthKey = direction
114111
srv.ControlURL = "http://localhost:8080"
115-
srv.Logf = logF
112+
srv.Logf = func(format string, args ...any) {}
113+
srv.UserLogf = func(format string, args ...any) {}
116114

117-
srv.Store, err = store.New(logF, "mem:wush")
115+
srv.Store, err = store.New(func(format string, args ...any) {}, "mem:wush")
118116
if err != nil {
119117
return nil, xerrors.Errorf("create state store: %w", err)
120118
}

cmd/wush/ssh.go

Lines changed: 47 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ func sshCmd() *serpent.Command {
2525
var (
2626
authID string
2727
waitP2P bool
28-
overlayTransport string
2928
stunAddrOverride string
3029
stunAddrOverrideIP netip.Addr
3130
sshStdio bool
@@ -37,6 +36,12 @@ func sshCmd() *serpent.Command {
3736
Handler: func(inv *serpent.Invocation) error {
3837
ctx := inv.Context()
3938
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
39+
logF := func(str string, args ...any) {
40+
if sshStdio {
41+
return
42+
}
43+
fmt.Fprintf(inv.Stderr, str, args...)
44+
}
4045
if authID == "" {
4146
err := huh.NewInput().
4247
Title("Enter the receiver's Auth ID:").
@@ -67,38 +72,31 @@ func sshCmd() *serpent.Command {
6772
return fmt.Errorf("parse auth key: %w", err)
6873
}
6974

70-
if !sshStdio {
71-
fmt.Println("Auth information:")
72-
stunStr := send.Auth.ReceiverStunAddr.String()
73-
if !send.Auth.ReceiverStunAddr.IsValid() {
74-
stunStr = "Disabled"
75-
}
76-
fmt.Println("\t> Server overlay STUN address:", cliui.Code(stunStr))
77-
derpStr := "Disabled"
78-
if send.Auth.ReceiverDERPRegionID > 0 {
79-
derpStr = dm.Regions[int(send.Auth.ReceiverDERPRegionID)].RegionName
80-
}
81-
fmt.Println("\t> Server overlay DERP home: ", cliui.Code(derpStr))
82-
fmt.Println("\t> Server overlay public key: ", cliui.Code(send.Auth.ReceiverPublicKey.ShortString()))
83-
fmt.Println("\t> Server overlay auth key: ", cliui.Code(send.Auth.OverlayPrivateKey.Public().ShortString()))
75+
logF("Auth information:")
76+
stunStr := send.Auth.ReceiverStunAddr.String()
77+
if !send.Auth.ReceiverStunAddr.IsValid() {
78+
stunStr = "Disabled"
8479
}
80+
logF("\t> Server overlay STUN address: %s", cliui.Code(stunStr))
81+
derpStr := "Disabled"
82+
if send.Auth.ReceiverDERPRegionID > 0 {
83+
derpStr = dm.Regions[int(send.Auth.ReceiverDERPRegionID)].RegionName
84+
}
85+
logF("\t> Server overlay DERP home: %s", cliui.Code(derpStr))
86+
logF("\t> Server overlay public key: %s", cliui.Code(send.Auth.ReceiverPublicKey.ShortString()))
87+
logF("\t> Server overlay auth key: %s", cliui.Code(send.Auth.OverlayPrivateKey.Public().ShortString()))
8588

8689
s, err := tsserver.NewServer(ctx, logger, send)
8790
if err != nil {
8891
return err
8992
}
9093

91-
switch overlayTransport {
92-
case "derp":
93-
if send.Auth.ReceiverDERPRegionID == 0 {
94-
return errors.New("overlay type is \"derp\", but receiver is of type \"stun\"")
95-
}
94+
if send.Auth.ReceiverDERPRegionID != 0 {
9695
go send.ListenOverlayDERP(ctx)
97-
case "stun":
98-
if !send.Auth.ReceiverStunAddr.IsValid() {
99-
return errors.New("overlay type is \"stun\", but receiver is of type \"derp\"")
100-
}
96+
} else if send.Auth.ReceiverStunAddr.IsValid() {
10197
go send.ListenOverlaySTUN(ctx)
98+
} else {
99+
return errors.New("auth key provided neither DERP nor STUN")
102100
}
103101

104102
go s.ListenAndServe(ctx)
@@ -110,22 +108,22 @@ func sshCmd() *serpent.Command {
110108
ts.Logf = func(string, ...any) {}
111109
ts.UserLogf = func(string, ...any) {}
112110

113-
// fmt.Println("Bringing Wireguard up..")
111+
logF("Bringing Wireguard up..")
114112
ts.Up(ctx)
115-
// fmt.Println("Wireguard is ready!")
113+
logF("Wireguard is ready!")
116114

117115
lc, err := ts.LocalClient()
118116
if err != nil {
119117
return err
120118
}
121119

122-
ip, err := waitUntilHasPeerHasIP(ctx, lc)
120+
ip, err := waitUntilHasPeerHasIP(ctx, logF, lc)
123121
if err != nil {
124122
return err
125123
}
126124

127125
if waitP2P {
128-
err := waitUntilHasP2P(ctx, lc)
126+
err := waitUntilHasP2P(ctx, logF, lc)
129127
if err != nil {
130128
return err
131129
}
@@ -141,12 +139,6 @@ func sshCmd() *serpent.Command {
141139
Default: "",
142140
Value: serpent.StringOf(&authID),
143141
},
144-
{
145-
Flag: "overlay-transport",
146-
Description: "The transport to use on the overlay. The overlay is used to exchange Wireguard nodes between peers. In DERP mode, nodes are exchanged over public Tailscale DERPs, while STUN mode sends nodes directly over UDP.",
147-
Default: "derp",
148-
Value: serpent.EnumOf(&overlayTransport, "derp", "stun"),
149-
},
150142
{
151143
Flag: "stun-ip-override",
152144
Default: "",
@@ -158,11 +150,17 @@ func sshCmd() *serpent.Command {
158150
Default: "false",
159151
Value: serpent.BoolOf(&sshStdio),
160152
},
153+
{
154+
Flag: "wait-p2p",
155+
Description: "Waits for the connection to be p2p.",
156+
Default: "false",
157+
Value: serpent.BoolOf(&sshStdio),
158+
},
161159
},
162160
}
163161
}
164162

165-
func waitUntilHasPeerHasIP(ctx context.Context, lc *tailscale.LocalClient) (netip.Addr, error) {
163+
func waitUntilHasPeerHasIP(ctx context.Context, logF func(str string, args ...any), lc *tailscale.LocalClient) (netip.Addr, error) {
166164
for {
167165
select {
168166
case <-ctx.Done():
@@ -178,35 +176,35 @@ func waitUntilHasPeerHasIP(ctx context.Context, lc *tailscale.LocalClient) (neti
178176

179177
peers := stat.Peers()
180178
if len(peers) == 0 {
181-
// fmt.Println("No peer yet")
179+
logF("No peer yet")
182180
continue
183181
}
184182

185-
// fmt.Println("Received peer")
183+
logF("Received peer")
186184

187185
peer, ok := stat.Peer[peers[0]]
188186
if !ok {
189-
fmt.Println("have peers but not found in map (developer error)")
187+
logF("have peers but not found in map (developer error)")
190188
continue
191189
}
192190

193191
if peer.Relay == "" {
194-
fmt.Println("peer no relay")
192+
logF("peer no relay")
195193
continue
196194
}
197195

198-
// fmt.Println("Peer active with relay", cliui.Code(peer.Relay))
196+
logF("Peer active with relay %s", cliui.Code(peer.Relay))
199197

200198
if len(peer.TailscaleIPs) == 0 {
201-
fmt.Println("peer has no ips (developer error)")
199+
logF("peer has no ips (developer error)")
202200
continue
203201
}
204202

205203
return peer.TailscaleIPs[0], nil
206204
}
207205
}
208206

209-
func waitUntilHasP2P(ctx context.Context, lc *tailscale.LocalClient) error {
207+
func waitUntilHasP2P(ctx context.Context, logF func(str string, args ...any), lc *tailscale.LocalClient) error {
210208
for {
211209
select {
212210
case <-ctx.Done():
@@ -216,43 +214,41 @@ func waitUntilHasP2P(ctx context.Context, lc *tailscale.LocalClient) error {
216214

217215
stat, err := lc.Status(ctx)
218216
if err != nil {
219-
fmt.Println("error getting lc status:", err)
217+
logF("error getting lc status: %s", err)
220218
continue
221219
}
222220

223221
peers := stat.Peers()
224222
peer, ok := stat.Peer[peers[0]]
225223
if !ok {
226-
fmt.Println("no peer found in map while waiting p2p (developer error)")
224+
logF("no peer found in map while waiting p2p (developer error)")
227225
continue
228226
}
229227

230228
if peer.Relay == "" {
231-
fmt.Println("peer no relay")
229+
logF("peer no relay")
232230
continue
233231
}
234232

235-
// fmt.Println("Peer active with relay", cliui.Code(peer.Relay))
236-
237233
if len(peer.TailscaleIPs) == 0 {
238-
fmt.Println("peer has no ips (developer error)")
234+
logF("peer has no ips (developer error)")
239235
continue
240236
}
241237

242238
pingCancel, cancel := context.WithTimeout(ctx, time.Second)
243239
pong, err := lc.Ping(pingCancel, peer.TailscaleIPs[0], tailcfg.PingDisco)
244240
cancel()
245241
if err != nil {
246-
fmt.Println("ping failed:", err)
242+
logF("ping failed: %s", err)
247243
continue
248244
}
249245

250246
if pong.Endpoint == "" {
251-
fmt.Println("not p2p yet")
247+
logF("Not p2p yet")
252248
continue
253249
}
254250

255-
// fmt.Println("Peer active over p2p", cliui.Code(pong.Endpoint))
251+
logF("Peer active over p2p %s", cliui.Code(pong.Endpoint))
256252
return nil
257253
}
258254
}

0 commit comments

Comments
 (0)