Skip to content

Commit 102e603

Browse files
committed
[v18] avoid join fallback attempts on legitimate join failure
Backport #60673 to branch/v18
1 parent 305682f commit 102e603

File tree

2 files changed

+207
-19
lines changed

2 files changed

+207
-19
lines changed

lib/join/join_test.go

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,134 @@ func TestJoin(t *testing.T) {
225225
})
226226
}
227227

228+
// TestJoinError asserts that attempts to join with an invalid token return an
229+
// AccessDenied error and do not fall back to joining via the legacy join
230+
// service.
231+
func TestJoinError(t *testing.T) {
232+
t.Parallel()
233+
234+
token, err := types.NewProvisionTokenFromSpec("token1", time.Now().Add(time.Minute), types.ProvisionTokenSpecV2{
235+
Roles: []types.SystemRole{
236+
types.RoleNode,
237+
types.RoleProxy,
238+
},
239+
})
240+
require.NoError(t, err)
241+
242+
authService := newFakeAuthService(t)
243+
require.NoError(t, authService.Auth().UpsertToken(t.Context(), token))
244+
245+
proxy := newFakeProxy(authService)
246+
proxy.join(t)
247+
proxyListener, err := net.Listen("tcp", "127.0.0.1:0")
248+
require.NoError(t, err)
249+
t.Cleanup(func() { proxyListener.Close() })
250+
proxy.runGRPCServer(t, proxyListener)
251+
252+
// Listen on a free port just to guarantee an address that will reject/close
253+
// all connection attempts.
254+
badListener, err := net.Listen("tcp", "127.0.0.1:0")
255+
require.NoError(t, err)
256+
testutils.RunTestBackgroundTask(t.Context(), t, &testutils.TestBackgroundTask{
257+
Name: "bad listener",
258+
Task: func(ctx context.Context) error {
259+
for {
260+
conn, err := badListener.Accept()
261+
if err != nil {
262+
if ctx.Err() != nil {
263+
return ctx.Err()
264+
}
265+
return err
266+
}
267+
conn.Close()
268+
}
269+
},
270+
Terminate: badListener.Close,
271+
})
272+
273+
// Assert that the real AccessDenied error is returned with various
274+
// configurations joining via an auth or proxy address.
275+
for _, tc := range []struct {
276+
desc string
277+
joinParams joinclient.JoinParams
278+
assertErr assert.ErrorAssertionFunc
279+
}{
280+
{
281+
desc: "auth direct",
282+
joinParams: joinclient.JoinParams{
283+
AuthServers: []utils.NetAddr{utils.FromAddr(authService.TLS.Listener.Addr())},
284+
},
285+
assertErr: func(t assert.TestingT, err error, msgAndArgs ...any) bool {
286+
// Should get AccessDenied and should not fall back to joining
287+
// via the legacy service.
288+
return assert.ErrorAs(t, err, new(*trace.AccessDeniedError)) &&
289+
assert.NotErrorAs(t, err, new(*joinclient.LegacyJoinError))
290+
},
291+
},
292+
{
293+
// With teleport config v2 or certain bot configurations a proxy
294+
// address is passed in AuthServers, which supports both auth and
295+
// proxy addresses.
296+
desc: "proxy as auth",
297+
joinParams: joinclient.JoinParams{
298+
AuthServers: []utils.NetAddr{utils.FromAddr(proxyListener.Addr())},
299+
Insecure: true,
300+
},
301+
assertErr: func(t assert.TestingT, err error, msgAndArgs ...any) bool {
302+
// Should get AccessDenied and should not fall back to joining
303+
// via the legacy service.
304+
return assert.ErrorAs(t, err, new(*trace.AccessDeniedError)) &&
305+
assert.NotErrorAs(t, err, new(*joinclient.LegacyJoinError))
306+
},
307+
},
308+
{
309+
desc: "proxy direct",
310+
joinParams: joinclient.JoinParams{
311+
ProxyServer: utils.FromAddr(proxyListener.Addr()),
312+
Insecure: true,
313+
},
314+
assertErr: func(t assert.TestingT, err error, msgAndArgs ...any) bool {
315+
// Should get AccessDenied and should not fall back to joining
316+
// via the legacy service.
317+
return assert.ErrorAs(t, err, new(*trace.AccessDeniedError)) &&
318+
assert.NotErrorAs(t, err, new(*joinclient.LegacyJoinError))
319+
},
320+
},
321+
{
322+
desc: "bad auth address",
323+
joinParams: joinclient.JoinParams{
324+
AuthServers: []utils.NetAddr{utils.FromAddr(badListener.Addr())},
325+
},
326+
assertErr: func(t assert.TestingT, err error, msgAndArgs ...any) bool {
327+
// Should fall back to a legacy join attempt before failing.
328+
return assert.ErrorAs(t, err, new(*joinclient.LegacyJoinError))
329+
},
330+
},
331+
{
332+
desc: "bad proxy address",
333+
joinParams: joinclient.JoinParams{
334+
ProxyServer: utils.FromAddr(badListener.Addr()),
335+
},
336+
assertErr: func(t assert.TestingT, err error, msgAndArgs ...any) bool {
337+
// Should fall back to a legacy join attempt before failing.
338+
return assert.ErrorAs(t, err, new(*joinclient.LegacyJoinError))
339+
},
340+
},
341+
} {
342+
t.Run(tc.desc, func(t *testing.T) {
343+
joinParams := tc.joinParams
344+
joinParams.ID = state.IdentityID{
345+
Role: types.RoleInstance,
346+
NodeName: "test",
347+
}
348+
349+
joinParams.Token = "invalid"
350+
_, err = joinclient.Join(t.Context(), joinParams)
351+
tc.assertErr(t, err)
352+
})
353+
}
354+
}
355+
228356
type fakeAuthService struct {
229357
*authtest.Server
230358
}

lib/join/joinclient/join.go

Lines changed: 79 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@ import (
2323
"encoding/pem"
2424
"errors"
2525
"log/slog"
26+
"strings"
2627

2728
"github.com/gravitational/trace"
2829
"golang.org/x/crypto/ssh"
30+
"google.golang.org/grpc/status"
2931

3032
"github.com/gravitational/teleport/api/client/proto"
3133
"github.com/gravitational/teleport/api/types"
@@ -61,9 +63,9 @@ func Join(ctx context.Context, params JoinParams) (*JoinResult, error) {
6163
}
6264
slog.InfoContext(ctx, "Trying to join with the new join service")
6365
result, err := joinNew(ctx, params)
64-
if trace.IsNotImplemented(err) || errors.As(err, new(*connectionError)) {
66+
if trace.IsNotImplemented(err) || isConnectionError(err) {
6567
// Fall back to joining via legacy service.
66-
slog.InfoContext(ctx, "Falling back to joining via the legacy join service", "error", err)
68+
slog.InfoContext(ctx, "Joining via new join service failed, falling back to joining via the legacy join service", "error", err)
6769
// Non-bots must provide their own host UUID when joining via legacy service.
6870
if params.ID.HostUUID == "" && params.ID.Role != types.RoleBot {
6971
hostID, err := hostid.Generate(ctx, params.JoinMethod)
@@ -74,7 +76,10 @@ func Join(ctx context.Context, params JoinParams) (*JoinResult, error) {
7476
params.ID.HostUUID = hostID
7577
}
7678
result, err := LegacyJoin(ctx, params)
77-
return result, trace.Wrap(err)
79+
if err != nil {
80+
return nil, trace.Wrap(&LegacyJoinError{err})
81+
}
82+
return result, nil
7883
}
7984
return result, trace.Wrap(err)
8085
}
@@ -91,34 +96,63 @@ func LegacyJoin(ctx context.Context, params JoinParams) (*JoinResult, error) {
9196

9297
func joinNew(ctx context.Context, params JoinParams) (*JoinResult, error) {
9398
if params.AuthClient != nil {
99+
slog.InfoContext(ctx, "Attempting to join cluster with existing Auth client")
94100
return joinViaAuthClient(ctx, params, params.AuthClient)
95101
}
96102
if !params.ProxyServer.IsEmpty() {
103+
slog.InfoContext(ctx, "Attempting to join cluster via Proxy")
97104
return joinViaProxy(ctx, params, params.ProxyServer.String())
98105
}
106+
99107
// params.AuthServers could contain auth or proxy addresses, try both.
100108
// params.CheckAndSetDefaults() asserts that this list is not empty when
101109
// AuthClient and ProxyServer are both unset.
110+
addr := params.AuthServers[0].String()
111+
slog := slog.With("addr", addr)
112+
113+
type strategy struct {
114+
name string
115+
fn func() (*JoinResult, error)
116+
}
117+
proxyStrategy := strategy{
118+
name: "proxy",
119+
fn: func() (*JoinResult, error) {
120+
return joinViaProxy(ctx, params, addr)
121+
},
122+
}
123+
authStrategy := strategy{
124+
name: "auth",
125+
fn: func() (*JoinResult, error) {
126+
return joinViaAuth(ctx, params)
127+
},
128+
}
129+
var strategies []strategy
102130
if authjoin.LooksLikeProxy(params.AuthServers) {
103-
proxyAddr := params.AuthServers[0].String()
104-
slog.InfoContext(ctx, "Attempting to join cluster, address looks like a Proxy", "addr", proxyAddr)
105-
result, proxyJoinErr := joinViaProxy(ctx, params, proxyAddr)
106-
if proxyJoinErr == nil {
131+
slog.InfoContext(ctx, "Attempting to join cluster, address looks like a Proxy")
132+
strategies = []strategy{proxyStrategy, authStrategy}
133+
} else {
134+
slog.InfoContext(ctx, "Attempting to join cluster, address looks like an Auth server")
135+
strategies = []strategy{authStrategy, proxyStrategy}
136+
}
137+
138+
var errs []error
139+
for i, strat := range strategies { //nolint:misspell // strat is an intentional abbreviation of strategy
140+
result, err := strat.fn()
141+
switch {
142+
case err == nil:
107143
return result, nil
144+
case !isConnectionError(err):
145+
// Non-connection errors are hard failures: return immediately.
146+
return nil, trace.Wrap(err, "joining via %s", strat.name)
147+
}
148+
// Connection error: keep for aggregate and try next strategy (if any).
149+
errs = append(errs, trace.Wrap(err, "joining via %s", strat.name))
150+
if i+1 < len(strategies) {
151+
slog.InfoContext(ctx, "Failed to join cluster with a connection error, will try next method",
152+
"method", strat.name, "next_method", strategies[i+1].name)
108153
}
109-
slog.InfoContext(ctx, "Joining via proxy failed, will try to join via Auth", "error", proxyJoinErr)
110-
result, authJoinErr := joinViaAuth(ctx, params)
111-
return result, trace.Wrap(authJoinErr)
112-
}
113-
addr := params.AuthServers[0].String()
114-
slog.InfoContext(ctx, "Attempting to join cluster, address looks like an Auth server", "addr", addr)
115-
result, authJoinErr := joinViaAuth(ctx, params)
116-
if authJoinErr == nil {
117-
return result, nil
118154
}
119-
slog.InfoContext(ctx, "Joining via auth failed, will try to join via Proxy", "error", authJoinErr)
120-
result, proxyJoinErr := joinViaProxy(ctx, params, addr)
121-
return result, trace.Wrap(proxyJoinErr)
155+
return nil, trace.NewAggregate(errs...)
122156
}
123157

124158
func joinViaProxy(ctx context.Context, params JoinParams, proxyAddr string) (*JoinResult, error) {
@@ -387,3 +421,29 @@ func (e *connectionError) Error() string {
387421
func (e *connectionError) Unwrap() error {
388422
return e.wrapped
389423
}
424+
425+
func isConnectionError(err error) bool {
426+
var ce *connectionError
427+
if errors.As(err, &ce) {
428+
return true
429+
}
430+
// It's possible to hit a gRPC status error like this when reading the
431+
// first response from the stream if connecting in insecure mode to a proxy
432+
// address provided as an auth address.
433+
statusErr, ok := status.FromError(err)
434+
return ok && strings.Contains(statusErr.Message(), "unexpected HTTP status code")
435+
}
436+
437+
// LegacyJoinError is returned when a join attempt made via the legacy join
438+
// service fails.
439+
type LegacyJoinError struct {
440+
wrapped error
441+
}
442+
443+
func (e *LegacyJoinError) Error() string {
444+
return e.wrapped.Error()
445+
}
446+
447+
func (e *LegacyJoinError) Unwrap() error {
448+
return e.wrapped
449+
}

0 commit comments

Comments
 (0)