Skip to content

Commit 5ea153d

Browse files
authored
chore: support old tunnel URLs and clients (#3)
1 parent 0f43df6 commit 5ea153d

File tree

11 files changed

+788
-239
lines changed

11 files changed

+788
-239
lines changed

cmd/tunnel/main.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net"
99
"net/url"
1010
"os"
11+
"os/signal"
1112
"time"
1213

1314
"github.com/urfave/cli/v2"
@@ -143,6 +144,12 @@ func runApp(ctx *cli.Context) error {
143144
}
144145
}()
145146

147+
_, _ = fmt.Fprintln(os.Stderr, "Tunnel is ready. You can now connect to one of the following URLs:")
148+
_, _ = fmt.Fprintln(os.Stderr, " -", tunnel.URL.String())
149+
for _, u := range tunnel.OtherURLs {
150+
_, _ = fmt.Fprintln(os.Stderr, " -", u.String())
151+
}
152+
146153
// Start forwarding traffic to/from the tunnel.
147154
go func() {
148155
for {
@@ -183,7 +190,15 @@ func runApp(ctx *cli.Context) error {
183190

184191
_, _ = fmt.Printf("\nTunnel is ready! You can now connect to %s\n", tunnel.URL.String())
185192

186-
// TODO: manual signal handling
187-
<-tunnel.Wait()
193+
notifyCtx, notifyStop := signal.NotifyContext(ctx.Context, InterruptSignals...)
194+
defer notifyStop()
195+
196+
select {
197+
case <-notifyCtx.Done():
198+
_, _ = fmt.Printf("\nClosing tunnel due to signal...\n")
199+
return tunnel.Close()
200+
case <-tunnel.Wait():
201+
}
202+
188203
return nil
189204
}

cmd/tunnel/signal_unix.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//go:build !windows
2+
3+
package main
4+
5+
import (
6+
"os"
7+
"syscall"
8+
)
9+
10+
var InterruptSignals = []os.Signal{
11+
os.Interrupt,
12+
syscall.SIGTERM,
13+
syscall.SIGHUP,
14+
}

cmd/tunnel/signal_windows.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//go:build windows
2+
3+
package main
4+
5+
import (
6+
"os"
7+
)
8+
9+
var InterruptSignals = []os.Signal{os.Interrupt}

tunneld/api.go

Lines changed: 106 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,21 @@ package tunneld
22

33
import (
44
"context"
5+
"encoding/hex"
56
"fmt"
67
"net"
78
"net/http"
89
"net/http/httputil"
910
"net/netip"
11+
"net/url"
1012
"strings"
1113
"time"
1214

1315
"github.com/go-chi/chi"
1416
"go.opentelemetry.io/otel/attribute"
1517
"go.opentelemetry.io/otel/trace"
18+
"golang.org/x/xerrors"
19+
"golang.zx2c4.com/wireguard/device"
1620

1721
"github.com/coder/wgtunnel/tunneld/httpapi"
1822
"github.com/coder/wgtunnel/tunneld/httpmw"
@@ -32,7 +36,8 @@ func (api *API) Router() chi.Router {
3236
httpmw.RateLimit(10, 10*time.Second),
3337
)
3438

35-
r.Post("/api/v1/clients", api.postClients)
39+
r.Post("/tun", api.postTun)
40+
r.Post("/api/v2/clients", api.postClients)
3641

3742
r.NotFound(func(rw http.ResponseWriter, r *http.Request) {
3843
httpapi.Write(r.Context(), rw, http.StatusNotFound, tunnelsdk.Response{
@@ -43,36 +48,127 @@ func (api *API) Router() chi.Router {
4348
return r
4449
}
4550

51+
type LegacyPostTunRequest struct {
52+
PublicKey device.NoisePublicKey `json:"public_key"`
53+
}
54+
55+
type LegacyPostTunResponse struct {
56+
Hostname string `json:"hostname"`
57+
ServerEndpoint string `json:"server_endpoint"`
58+
ServerIP netip.Addr `json:"server_ip"`
59+
ServerPublicKey string `json:"server_public_key"` // hex
60+
ClientIP netip.Addr `json:"client_ip"`
61+
}
62+
63+
// postTun provides compatibility with the old tunnel client contained in older
64+
// versions of coder/coder. It essentially converts the old request format to a
65+
// newer request, and the newer response to the old response format.
66+
func (api *API) postTun(rw http.ResponseWriter, r *http.Request) {
67+
ctx := r.Context()
68+
69+
var req LegacyPostTunRequest
70+
if !httpapi.Read(ctx, rw, r, &req) {
71+
return
72+
}
73+
74+
registerReq := tunnelsdk.ClientRegisterRequest{
75+
Version: tunnelsdk.TunnelVersion1,
76+
PublicKey: req.PublicKey,
77+
}
78+
79+
resp, exists, err := api.registerClient(registerReq)
80+
if err != nil {
81+
httpapi.Write(ctx, rw, http.StatusInternalServerError, tunnelsdk.Response{
82+
Message: "Failed to register client.",
83+
Detail: err.Error(),
84+
})
85+
return
86+
}
87+
88+
if len(resp.TunnelURLs) == 0 {
89+
httpapi.Write(ctx, rw, http.StatusInternalServerError, tunnelsdk.Response{
90+
Message: "No tunnel URLs found.",
91+
})
92+
return
93+
}
94+
95+
u, err := url.Parse(resp.TunnelURLs[0])
96+
if err != nil {
97+
httpapi.Write(ctx, rw, http.StatusInternalServerError, tunnelsdk.Response{
98+
Message: "Failed to parse tunnel URL.",
99+
Detail: err.Error(),
100+
})
101+
return
102+
}
103+
104+
status := http.StatusCreated
105+
if exists {
106+
status = http.StatusOK
107+
}
108+
httpapi.Write(ctx, rw, status, LegacyPostTunResponse{
109+
Hostname: u.Host,
110+
ServerEndpoint: resp.ServerEndpoint,
111+
ServerIP: resp.ServerIP,
112+
ServerPublicKey: hex.EncodeToString(resp.ServerPublicKey[:]),
113+
ClientIP: resp.ClientIP,
114+
})
115+
}
116+
46117
func (api *API) postClients(rw http.ResponseWriter, r *http.Request) {
118+
ctx := r.Context()
119+
47120
var req tunnelsdk.ClientRegisterRequest
48121
if !httpapi.Read(r.Context(), rw, r, &req) {
49122
return
50123
}
51124

52-
ip := api.WireguardPublicKeyToIP(req.PublicKey)
125+
resp, _, err := api.registerClient(req)
126+
if err != nil {
127+
httpapi.Write(ctx, rw, http.StatusInternalServerError, tunnelsdk.Response{
128+
Message: "Failed to register client.",
129+
Detail: err.Error(),
130+
})
131+
return
132+
}
133+
134+
httpapi.Write(ctx, rw, http.StatusOK, resp)
135+
}
136+
137+
func (api *API) registerClient(req tunnelsdk.ClientRegisterRequest) (tunnelsdk.ClientRegisterResponse, bool, error) {
138+
if req.Version <= 0 || req.Version > tunnelsdk.TunnelVersionLatest {
139+
req.Version = tunnelsdk.TunnelVersionLatest
140+
}
141+
142+
ip, urls := api.WireguardPublicKeyToIPAndURLs(req.PublicKey, req.Version)
143+
144+
exists := true
53145
if api.wgDevice.LookupPeer(req.PublicKey) == nil {
146+
exists = false
147+
54148
err := api.wgDevice.IpcSet(fmt.Sprintf(`public_key=%x
55149
allowed_ip=%s/128`,
56150
req.PublicKey,
57151
ip.String(),
58152
))
59153
if err != nil {
60-
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, tunnelsdk.Response{
61-
Message: "Failed to register client.",
62-
Detail: err.Error(),
63-
})
64-
return
154+
return tunnelsdk.ClientRegisterResponse{}, false, xerrors.Errorf("register client with wireguard: %w", err)
65155
}
66156
}
67157

68-
httpapi.Write(r.Context(), rw, http.StatusOK, tunnelsdk.ClientRegisterResponse{
69-
TunnelURL: api.WireguardIPToTunnelURL(ip).String(),
158+
urlsStr := make([]string, len(urls))
159+
for i, u := range urls {
160+
urlsStr[i] = u.String()
161+
}
162+
163+
return tunnelsdk.ClientRegisterResponse{
164+
Version: req.Version,
165+
TunnelURLs: urlsStr,
70166
ClientIP: ip,
71167
ServerEndpoint: api.WireguardEndpoint,
72168
ServerIP: api.WireguardServerIP,
73169
ServerPublicKey: api.WireguardKey.NoisePublicKey(),
74170
WireguardMTU: api.WireguardMTU,
75-
})
171+
}, exists, nil
76172
}
77173

78174
func (api *API) handleTunnelMW(next http.Handler) http.Handler {

tunneld/api_test.go

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,73 @@ package tunneld_test
22

33
import (
44
"context"
5+
"encoding/hex"
6+
"encoding/json"
7+
"net/http"
8+
"strings"
59
"testing"
610

711
"github.com/stretchr/testify/require"
812

13+
"github.com/coder/wgtunnel/tunneld"
914
"github.com/coder/wgtunnel/tunnelsdk"
1015
)
1116

17+
// Test for the compatibility endpoint which allows old tunnels to connect to
18+
// the new server.
19+
func Test_postTun(t *testing.T) {
20+
t.Parallel()
21+
22+
td, client := createTestTunneld(t, nil)
23+
24+
key, err := tunnelsdk.GeneratePrivateKey()
25+
require.NoError(t, err)
26+
27+
expectedIP, expectedURLs := td.WireguardPublicKeyToIPAndURLs(key.NoisePublicKey(), tunnelsdk.TunnelVersion1)
28+
require.Len(t, expectedURLs, 2)
29+
require.Len(t, strings.Split(expectedURLs[0].Host, ".")[0], 32)
30+
expectedHostname := expectedURLs[0].Host
31+
32+
// First request should return a 201.
33+
resp, err := client.Request(context.Background(), http.MethodPost, "/tun", tunneld.LegacyPostTunRequest{
34+
PublicKey: key.NoisePublicKey(),
35+
})
36+
require.NoError(t, err)
37+
defer resp.Body.Close()
38+
require.Equal(t, http.StatusCreated, resp.StatusCode)
39+
40+
var legacyRes tunneld.LegacyPostTunResponse
41+
require.NoError(t, json.NewDecoder(resp.Body).Decode(&legacyRes))
42+
require.Equal(t, expectedIP, legacyRes.ClientIP)
43+
require.Equal(t, expectedHostname, legacyRes.Hostname)
44+
45+
// Register on the new endpoint so we can compare the values to the legacy
46+
// endpoint.
47+
newRes, err := client.ClientRegister(context.Background(), tunnelsdk.ClientRegisterRequest{
48+
Version: tunnelsdk.TunnelVersion1,
49+
PublicKey: key.NoisePublicKey(),
50+
})
51+
require.NoError(t, err)
52+
require.Equal(t, tunnelsdk.TunnelVersion1, newRes.Version)
53+
54+
require.Equal(t, legacyRes.ServerEndpoint, newRes.ServerEndpoint)
55+
require.Equal(t, legacyRes.ServerIP, newRes.ServerIP)
56+
require.Equal(t, legacyRes.ServerPublicKey, hex.EncodeToString(newRes.ServerPublicKey[:]))
57+
require.Equal(t, legacyRes.ClientIP, newRes.ClientIP)
58+
59+
// Second request should return a 200.
60+
resp, err = client.Request(context.Background(), http.MethodPost, "/tun", tunneld.LegacyPostTunRequest{
61+
PublicKey: key.NoisePublicKey(),
62+
})
63+
require.NoError(t, err)
64+
defer resp.Body.Close()
65+
require.Equal(t, http.StatusOK, resp.StatusCode)
66+
67+
var legacyRes2 tunneld.LegacyPostTunResponse
68+
require.NoError(t, json.NewDecoder(resp.Body).Decode(&legacyRes2))
69+
require.Equal(t, legacyRes, legacyRes2)
70+
}
71+
1272
func Test_postClients(t *testing.T) {
1373
t.Parallel()
1474

@@ -17,16 +77,22 @@ func Test_postClients(t *testing.T) {
1777
key, err := tunnelsdk.GeneratePrivateKey()
1878
require.NoError(t, err)
1979

20-
expectedIP := td.WireguardPublicKeyToIP(key.NoisePublicKey())
21-
expectedURL := td.WireguardIPToTunnelURL(expectedIP)
80+
expectedIP, expectedURLs := td.WireguardPublicKeyToIPAndURLs(key.NoisePublicKey(), tunnelsdk.TunnelVersion2)
81+
82+
expectedURLsStr := make([]string, len(expectedURLs))
83+
for i, u := range expectedURLs {
84+
expectedURLsStr[i] = u.String()
85+
}
2286

2387
// Register a client.
2488
res, err := client.ClientRegister(context.Background(), tunnelsdk.ClientRegisterRequest{
89+
// No version should default to 2.
2590
PublicKey: key.NoisePublicKey(),
2691
})
2792
require.NoError(t, err)
2893

29-
require.Equal(t, expectedURL.String(), res.TunnelURL)
94+
require.Equal(t, tunnelsdk.TunnelVersion2, res.Version)
95+
require.Equal(t, expectedURLsStr, res.TunnelURLs)
3096
require.Equal(t, expectedIP, res.ClientIP)
3197
require.Equal(t, td.WireguardEndpoint, res.ServerEndpoint)
3298
require.Equal(t, td.WireguardServerIP, res.ServerIP)
@@ -35,8 +101,22 @@ func Test_postClients(t *testing.T) {
35101

36102
// Register the same client again.
37103
res2, err := client.ClientRegister(context.Background(), tunnelsdk.ClientRegisterRequest{
104+
Version: tunnelsdk.TunnelVersion2,
38105
PublicKey: key.NoisePublicKey(),
39106
})
40107
require.NoError(t, err)
41108
require.Equal(t, res, res2)
109+
110+
// Register the same client with the old version.
111+
res3, err := client.ClientRegister(context.Background(), tunnelsdk.ClientRegisterRequest{
112+
Version: tunnelsdk.TunnelVersion1,
113+
PublicKey: key.NoisePublicKey(),
114+
})
115+
require.NoError(t, err)
116+
117+
// Should be equal after reversing the URL list.
118+
require.Equal(t, tunnelsdk.TunnelVersion1, res3.Version)
119+
res3.TunnelURLs[0], res3.TunnelURLs[1] = res3.TunnelURLs[1], res3.TunnelURLs[0]
120+
res3.Version = tunnelsdk.TunnelVersion2
121+
require.Equal(t, res, res3)
42122
}

0 commit comments

Comments
 (0)