From c451e15178d4c2f0af7c46f2f8c7e9d7cf5a4587 Mon Sep 17 00:00:00 2001 From: Nic Klaassen Date: Wed, 1 Oct 2025 09:16:37 -0700 Subject: [PATCH] [v18] feat: client side impl for new join service Backport #59341 to branch/v18 --- lib/auth/join/join.go | 54 +++-- lib/client/proxy/insecure/insecure.go | 7 +- lib/join/join_test.go | 283 ++++++++--------------- lib/join/joinclient/join.go | 321 ++++++++++++++++++++++++++ lib/join/joinv1/client.go | 7 + 5 files changed, 458 insertions(+), 214 deletions(-) create mode 100644 lib/join/joinclient/join.go diff --git a/lib/auth/join/join.go b/lib/auth/join/join.go index 1a4a6a63a5b31..94ac4c64f2229 100644 --- a/lib/auth/join/join.go +++ b/lib/auth/join/join.go @@ -37,6 +37,7 @@ import ( "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/client/webclient" + joinv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/join/v1" "github.com/gravitational/teleport/api/observability/tracing" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/aws" @@ -191,7 +192,7 @@ type RegisterParams struct { BoundKeypairParams *BoundKeypairParams } -func (r *RegisterParams) checkAndSetDefaults() error { +func (r *RegisterParams) CheckAndSetDefaults() error { if r.Clock == nil { r.Clock = clockwork.NewRealClock() } @@ -264,7 +265,7 @@ func Register(ctx context.Context, params RegisterParams) (result *RegisterResul ctx, span := tracer.Start(ctx, "Register") defer func() { tracing.EndSpan(span, err) }() - if err := params.checkAndSetDefaults(); err != nil { + if err := params.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } // Read in the token. The token can either be passed in or come from a file @@ -378,7 +379,7 @@ func Register(ctx context.Context, params RegisterParams) (result *RegisterResul if params.GetHostCredentials == nil { slog.DebugContext(ctx, "Missing client, it is not possible to register through proxy.") registerMethods = []registerMethod{registerThroughAuth} - } else if authServerIsProxy(params.AuthServers) { + } else if LooksLikeProxy(params.AuthServers) { slog.DebugContext(ctx, "The first specified auth server appears to be a proxy.") registerMethods = []registerMethod{registerThroughProxy, registerThroughAuth} } @@ -399,9 +400,9 @@ func Register(ctx context.Context, params RegisterParams) (result *RegisterResul return nil, trace.NewAggregate(collectedErrs...) } -// authServerIsProxy returns true if the first specified auth server +// LooksLikeProxy returns true if the first specified auth server // to register with appears to be a proxy. -func authServerIsProxy(servers []utils.NetAddr) bool { +func LooksLikeProxy(servers []utils.NetAddr) bool { if len(servers) == 0 { return false } @@ -506,25 +507,7 @@ func registerThroughAuth( ctx, span := tracer.Start(ctx, "registerThroughAuth") defer func() { tracing.EndSpan(span, err) }() - var client *authclient.Client - // Build a client for the Auth Server with different certificate validation - // depending on the configured values for Insecure, CAPins and CAPath. - switch { - case params.Insecure: - slog.WarnContext(ctx, "Insecure mode enabled. Auth Server cert will not be validated and CAPins and CAPath value will be ignored.") - client, err = insecureRegisterClient(ctx, params) - case len(params.CAPins) != 0: - // CAPins takes precedence over CAPath - client, err = pinRegisterClient(ctx, params) - case params.CAPath != "": - client, err = caPathRegisterClient(ctx, params) - default: - // We fall back to insecure mode here - this is a little odd but is - // necessary to preserve the behavior of registration. At a later date, - // we may consider making this an error asking the user to provide - // Insecure, CAPins or CAPath. - client, err = insecureRegisterClient(ctx, params) - } + client, err := NewAuthClient(ctx, params) if err != nil { return nil, trace.Wrap(err, "building auth client") } @@ -540,6 +523,7 @@ type AuthJoinClient interface { joinServiceClient RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error) Ping(ctx context.Context) (proto.PingResponse, error) + JoinV1Client() joinv1.JoinServiceClient } func registerThroughAuthClient( @@ -593,6 +577,28 @@ func getHostAddresses(params RegisterParams) []string { return utils.NetAddrsToStrings(params.AuthServers) } +// NewAuthClient returns a new auth client built according to the register +// params, preferring the authenticate the server via CA pins or a CA path and +// falling back to an insecure connection, unless insecure mode was explicitly enabled. +func NewAuthClient(ctx context.Context, params RegisterParams) (*authclient.Client, error) { + switch { + case params.Insecure: + slog.WarnContext(ctx, "Insecure mode enabled. Auth Server cert will not be validated and CAPins and CAPath value will be ignored.") + return insecureRegisterClient(ctx, params) + case len(params.CAPins) != 0: + // CAPins takes precedence over CAPath + return pinRegisterClient(ctx, params) + case params.CAPath != "": + return caPathRegisterClient(ctx, params) + default: + // We fall back to insecure mode here - this is a little odd but is + // necessary to preserve the behavior of registration. At a later date, + // we may consider making this an error asking the user to provide + // Insecure, CAPins or CAPath. + return insecureRegisterClient(ctx, params) + } +} + // insecureRegisterClient attempts to connects to the Auth Server using the // CA on disk. If no CA is found on disk, Teleport will not verify the Auth // Server it is connecting to. diff --git a/lib/client/proxy/insecure/insecure.go b/lib/client/proxy/insecure/insecure.go index 7818cf773c06b..766be08177e7e 100644 --- a/lib/client/proxy/insecure/insecure.go +++ b/lib/client/proxy/insecure/insecure.go @@ -27,6 +27,7 @@ import ( "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" "golang.org/x/net/http2" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -35,6 +36,7 @@ import ( "github.com/gravitational/teleport/api/constants" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/metadata" + "github.com/gravitational/teleport/api/utils/grpc/interceptors" "github.com/gravitational/teleport/lib/srv/alpnproxy/common" "github.com/gravitational/teleport/lib/utils" ) @@ -85,9 +87,10 @@ func NewConnection( conn, err := grpc.Dial( params.ProxyServer, grpc.WithContextDialer(client.GRPCContextDialer(dialer)), - grpc.WithUnaryInterceptor(metadata.UnaryClientInterceptor), - grpc.WithStreamInterceptor(metadata.StreamClientInterceptor), + grpc.WithChainUnaryInterceptor(metadata.UnaryClientInterceptor, interceptors.GRPCClientUnaryErrorInterceptor), + grpc.WithChainStreamInterceptor(metadata.StreamClientInterceptor, interceptors.GRPCClientStreamErrorInterceptor), grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), + grpc.WithStatsHandler(otelgrpc.NewClientHandler()), ) return conn, trace.Wrap(err) } diff --git a/lib/join/join_test.go b/lib/join/join_test.go index 6bcd109fdd616..bdc04892aa873 100644 --- a/lib/join/join_test.go +++ b/lib/join/join_test.go @@ -18,9 +18,6 @@ package join_test import ( "context" - "crypto" - "crypto/tls" - "crypto/x509" "net" "slices" "testing" @@ -31,23 +28,26 @@ import ( "github.com/gravitational/trace" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/crypto/ssh" + "golang.org/x/net/http2" "google.golang.org/grpc" "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" "google.golang.org/protobuf/testing/protocmp" + "github.com/gravitational/teleport/api/constants" joinv1proto "github.com/gravitational/teleport/api/gen/proto/go/teleport/join/v1" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/grpc/interceptors" + "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/auth/authtest" - "github.com/gravitational/teleport/lib/cryptosuites" + authjoin "github.com/gravitational/teleport/lib/auth/join" + "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/events" - "github.com/gravitational/teleport/lib/join/internal/messages" + "github.com/gravitational/teleport/lib/join/joinclient" "github.com/gravitational/teleport/lib/join/joinv1" - "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/lib/srv/alpnproxy/common" + "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/testutils" ) @@ -93,11 +93,10 @@ func TestJoin(t *testing.T) { proxy.runGRPCServer(t, proxyListener) t.Run("invalid token", func(t *testing.T) { - _, _, err := join( + _, err := joinViaProxy( t.Context(), - proxyListener.Addr(), - insecure.NewCredentials(), "invalidtoken", + proxyListener.Addr(), ) require.ErrorAs(t, err, new(*trace.AccessDeniedError)) ctx := t.Context() @@ -129,51 +128,40 @@ func TestJoin(t *testing.T) { }) t.Run("join and rejoin", func(t *testing.T) { - // Node joins by connecting to the proxy's gRPC service. - joinResult, signer, err := join( + // Node initially joins by connecting to the proxy's gRPC service. + identity, err := joinViaProxy( t.Context(), - proxyListener.Addr(), - insecure.NewCredentials(), token1.GetName(), + proxyListener.Addr(), ) - // Make sure the result contains a host ID and expected certificate roles. - require.NoError(t, err) - require.NotNil(t, joinResult.HostID) - require.NotEmpty(t, joinResult.HostID) - cert, err := x509.ParseCertificate(joinResult.Certificates.TLSCert) - require.NoError(t, err) - identity, err := tlsca.FromSubject(cert.Subject, cert.NotAfter) require.NoError(t, err) - require.Len(t, identity.Groups, 1) - require.Equal(t, identity.Groups[0], types.RoleInstance.String()) + // Make sure the result contains a host ID and expected certificate roles. + require.NotEmpty(t, identity.ID.HostUUID) + require.Equal(t, types.RoleInstance, identity.ID.Role) expectedSystemRoles := slices.DeleteFunc( token1.GetRoles().StringSlice(), func(s string) bool { return s == types.RoleInstance.String() }, ) require.ElementsMatch(t, expectedSystemRoles, identity.SystemRoles) + // Build an auth client with the new identity. + tlsConfig, err := identity.TLSConfig(nil /*cipherSuites*/) + require.NoError(t, err) + authClient, err := authService.TLS.NewClientWithCert(tlsConfig.Certificates[0]) + require.NoError(t, err) + // Node can rejoin with a different token by dialing the auth service - // with its original credentials (for this test we omit the details of - // the proxy's mTLS tunnel dialing and let the node dial auth - // directly). + // with an auth client authenticed with its original credentials. // // It should get back its original host ID and the combined roles of // its original certificate and the new token. - creds, err := clientCreds(signer, joinResult.Certificates) - require.NoError(t, err) - rejoinResult, _, err := join( + newIdentity, err := rejoinViaAuthClient( t.Context(), - authService.TLS.Listener.Addr(), - creds, token2.GetName(), + authClient, ) require.NoError(t, err) - cert, err = x509.ParseCertificate(rejoinResult.Certificates.TLSCert) - require.NoError(t, err) - identity, err = tlsca.FromSubject(cert.Subject, cert.NotAfter) - require.NoError(t, err) - require.Len(t, identity.Groups, 1) - require.Equal(t, identity.Groups[0], types.RoleInstance.String()) + require.Equal(t, identity.ID, newIdentity.ID) expectedSystemRoles = slices.DeleteFunc( apiutils.Deduplicate(slices.Concat( token1.GetRoles().StringSlice(), @@ -181,31 +169,29 @@ func TestJoin(t *testing.T) { )), func(s string) bool { return s == types.RoleInstance.String() }, ) - require.ElementsMatch(t, expectedSystemRoles, identity.SystemRoles) - - // The node gets back its original host ID when rejoining with an - // authenticated client. - require.Equal(t, joinResult.HostID, rejoinResult.HostID) + require.ElementsMatch(t, expectedSystemRoles, newIdentity.SystemRoles) }) t.Run("join and rejoin with bad token", func(t *testing.T) { // Node joins by connecting to the proxy's gRPC service. - joinResult, signer, err := join( + identity, err := joinViaProxy( t.Context(), - proxyListener.Addr(), - insecure.NewCredentials(), token1.GetName(), + proxyListener.Addr(), ) require.NoError(t, err) - // Node the tries to rejoin with valid certs but an invalid token. - creds, err := clientCreds(signer, joinResult.Certificates) + // Build an auth client with the new identity. + tlsConfig, err := identity.TLSConfig(nil /*cipherSuites*/) + require.NoError(t, err) + authClient, err := authService.TLS.NewClientWithCert(tlsConfig.Certificates[0]) require.NoError(t, err) - _, _, err = join( + + // Node the tries to rejoin with valid certs but an invalid token. + _, err = rejoinViaAuthClient( t.Context(), - authService.TLS.Listener.Addr(), - creds, "invalidtoken", + authClient, ) require.ErrorAs(t, err, new(*trace.AccessDeniedError)) ctx := t.Context() @@ -272,8 +258,8 @@ func (s *fakeAuthService) lastEvent(ctx context.Context, eventType string) (apie } type fakeProxy struct { - auth *fakeAuthService - authenticatedAuthCreds credentials.TransportCredentials + auth *fakeAuthService + identity *state.Identity } func newFakeProxy(auth *fakeAuthService) *fakeProxy { @@ -285,54 +271,35 @@ func newFakeProxy(auth *fakeAuthService) *fakeProxy { func (p *fakeProxy) join(t *testing.T) { unauthenticatedAuthClt, err := p.auth.NewClient(authtest.TestNop()) require.NoError(t, err) - joinClient := joinv1.NewClient(unauthenticatedAuthClt.JoinV1Client()) - // Initiate the join request and get a client stream. - stream, err := joinClient.Join(t.Context()) + joinResult, err := joinclient.Join(t.Context(), joinclient.JoinParams{ + Token: "token1", + ID: state.IdentityID{ + Role: types.RoleInstance, + NodeName: "proxy", + }, + AuthClient: unauthenticatedAuthClt, + DNSNames: []string{"proxy"}, + AdditionalPrincipals: []string{"127.0.0.1"}, + }) require.NoError(t, err) - // Send the ClientInit messaage. - require.NoError(t, stream.Send(&messages.ClientInit{ - TokenName: "token1", - SystemRole: types.RoleInstance.String(), - })) - - // Wait for the ServerInit response. - serverInit, err := messages.RecvResponse[*messages.ServerInit](stream) + privateKeyPEM, err := keys.MarshalPrivateKey(joinResult.PrivateKey) require.NoError(t, err) - - require.Equal(t, string(types.JoinMethodToken), serverInit.JoinMethod) - - // Generate host keys with the suite from the ServerInit message. - hostKeys, err := genHostKeys(t.Context(), serverInit.SignatureAlgorithmSuite) + p.identity, err = state.ReadIdentityFromKeyPair(privateKeyPEM, joinResult.Certs) require.NoError(t, err) +} - // Send the TokenInit message. - require.NoError(t, stream.Send(&messages.TokenInit{ - ClientParams: messages.ClientParams{ - HostParams: &messages.HostParams{ - PublicKeys: messages.PublicKeys{ - PublicTLSKey: hostKeys.tlsPubKey, - PublicSSHKey: hostKeys.sshPubKey, - }, - HostName: "proxy", - AdditionalPrincipals: []string{"proxy"}, - }, - }, - })) - - // Wait for the result from the server. - result, err := messages.RecvResponse[*messages.HostResult](stream) +func (p *fakeProxy) runGRPCServer(t *testing.T, l net.Listener) { + tlsConfig, err := p.identity.TLSConfig(nil /*cipherSuites*/) require.NoError(t, err) + // Set NextProtos such that the ALPN conn upgrade test passes. + tlsConfig.NextProtos = []string{string(constants.ALPNSNIProtocolReverseTunnel), string(common.ProtocolProxyGRPCInsecure), http2.NextProtoTLS} - // Save the host credentials we got from the successful join. - p.authenticatedAuthCreds, err = clientCreds(hostKeys.tls, result.Certificates) - require.NoError(t, err) -} + grpcCreds := credentials.NewTLS(tlsConfig) -func (p *fakeProxy) runGRPCServer(t *testing.T, l net.Listener) { authenticatedAuthClientConn, err := grpc.NewClient(p.auth.TLS.Listener.Addr().String(), - grpc.WithTransportCredentials(p.authenticatedAuthCreds), + grpc.WithTransportCredentials(grpcCreds), grpc.WithStreamInterceptor(interceptors.GRPCClientStreamErrorInterceptor), ) require.NoError(t, err) @@ -341,6 +308,7 @@ func (p *fakeProxy) runGRPCServer(t *testing.T, l net.Listener) { }) grpcServer := grpc.NewServer( + grpc.Creds(grpcCreds), grpc.StreamInterceptor(interceptors.GRPCServerStreamErrorInterceptor), ) joinv1.RegisterProxyForwardingJoinServiceServer(grpcServer, joinv1proto.NewJoinServiceClient(authenticatedAuthClientConn)) @@ -357,116 +325,55 @@ func (p *fakeProxy) runGRPCServer(t *testing.T, l net.Listener) { }) } -func join( +func joinViaProxy( ctx context.Context, - addr net.Addr, - creds credentials.TransportCredentials, token string, -) (*messages.HostResult, crypto.Signer, error) { - conn, err := grpc.NewClient(addr.String(), - grpc.WithTransportCredentials(creds), - grpc.WithStreamInterceptor(interceptors.GRPCClientStreamErrorInterceptor), - ) - if err != nil { - return nil, nil, trace.Wrap(err) - } - defer conn.Close() - joinClient := joinv1.NewClient(joinv1proto.NewJoinServiceClient(conn)) - - // Initiate the join request. - stream, err := joinClient.Join(ctx) - if err != nil { - return nil, nil, trace.Wrap(err) - } - - // Send the ClientInit message. - err = stream.Send(&messages.ClientInit{ - TokenName: token, - SystemRole: types.RoleInstance.String(), + addr net.Addr, +) (*state.Identity, error) { + joinResult, err := joinclient.Join(ctx, joinclient.JoinParams{ + Token: token, + ID: state.IdentityID{ + Role: types.RoleInstance, + NodeName: "node", + }, + ProxyServer: utils.NetAddr{ + AddrNetwork: addr.Network(), + Addr: addr.String(), + }, + AdditionalPrincipals: []string{"node"}, + // The proxy's TLS cert for the test is not trusted. + Insecure: true, }) if err != nil { - return nil, nil, trace.Wrap(err) - } - - // Wait for the ServerInit response. - serverInit, err := messages.RecvResponse[*messages.ServerInit](stream) - if err != nil { - return nil, nil, trace.Wrap(err) - } - - // Generate host keys with the suite from the ServerInit message. - hostKeys, err := genHostKeys(ctx, serverInit.SignatureAlgorithmSuite) - if err != nil { - return nil, nil, trace.Wrap(err) - } - - // Send the TokenInit message with the host keys. - if err := stream.Send(&messages.TokenInit{ - ClientParams: messages.ClientParams{ - HostParams: &messages.HostParams{ - PublicKeys: messages.PublicKeys{ - PublicTLSKey: hostKeys.tlsPubKey, - PublicSSHKey: hostKeys.sshPubKey, - }, - HostName: "node", - }, - }, - }); err != nil { - return nil, nil, trace.Wrap(err) + return nil, trace.Wrap(err) } - - // Wait for the result. - result, err := messages.RecvResponse[*messages.HostResult](stream) + privateKeyPEM, err := keys.MarshalPrivateKey(joinResult.PrivateKey) if err != nil { - return nil, nil, trace.Wrap(err) - } - return result, hostKeys.tls, nil -} - -func clientCreds(tlsKey crypto.PrivateKey, certs messages.Certificates) (credentials.TransportCredentials, error) { - caPool := x509.NewCertPool() - for _, caCertDER := range certs.TLSCACerts { - caCert, err := x509.ParseCertificate(caCertDER) - if err != nil { - return nil, trace.Wrap(err) - } - caPool.AddCert(caCert) + return nil, trace.Wrap(err) } - return credentials.NewTLS(&tls.Config{ - Certificates: []tls.Certificate{{ - Certificate: [][]byte{certs.TLSCert}, - PrivateKey: tlsKey, - }}, - RootCAs: caPool, - ServerName: "teleport.cluster.local", - }), nil + return state.ReadIdentityFromKeyPair(privateKeyPEM, joinResult.Certs) } -type hostKeys struct { - tls crypto.Signer - tlsPubKey []byte - ssh ssh.Signer - sshPubKey []byte -} - -func genHostKeys(ctx context.Context, suite types.SignatureAlgorithmSuite) (*hostKeys, error) { - signer, err := cryptosuites.GenerateKey(ctx, cryptosuites.StaticAlgorithmSuite(suite), cryptosuites.HostIdentity) - if err != nil { - return nil, trace.Wrap(err) - } - tlsPubKey, err := x509.MarshalPKIXPublicKey(signer.Public()) +func rejoinViaAuthClient( + ctx context.Context, + token string, + authClient authjoin.AuthJoinClient, +) (*state.Identity, error) { + joinResult, err := joinclient.Join(ctx, joinclient.JoinParams{ + Token: token, + ID: state.IdentityID{ + Role: types.RoleInstance, + NodeName: "node", + }, + AdditionalPrincipals: []string{"node"}, + AuthClient: authClient, + }) if err != nil { return nil, trace.Wrap(err) } - sshKey, err := ssh.NewSignerFromSigner(signer) + privateKeyPEM, err := keys.MarshalPrivateKey(joinResult.PrivateKey) if err != nil { return nil, trace.Wrap(err) } - sshPubKey := sshKey.PublicKey().Marshal() - return &hostKeys{ - tls: signer, - tlsPubKey: tlsPubKey, - ssh: sshKey, - sshPubKey: sshPubKey, - }, nil + return state.ReadIdentityFromKeyPair(privateKeyPEM, joinResult.Certs) } diff --git a/lib/join/joinclient/join.go b/lib/join/joinclient/join.go new file mode 100644 index 0000000000000..d20b35777576b --- /dev/null +++ b/lib/join/joinclient/join.go @@ -0,0 +1,321 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package joinclient + +import ( + "context" + "crypto" + "crypto/x509" + "encoding/pem" + "log/slog" + + "github.com/gravitational/trace" + "golang.org/x/crypto/ssh" + + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" + authjoin "github.com/gravitational/teleport/lib/auth/join" + proxyinsecureclient "github.com/gravitational/teleport/lib/client/proxy/insecure" + "github.com/gravitational/teleport/lib/cryptosuites" + "github.com/gravitational/teleport/lib/join/internal/messages" + "github.com/gravitational/teleport/lib/join/joinv1" +) + +type JoinParams = authjoin.RegisterParams +type JoinResult = authjoin.RegisterResult + +// Join is used to join a cluster. A host or bot calls this with the name of a +// provision token to get its initial certificates. +func Join(ctx context.Context, params JoinParams) (*JoinResult, error) { + if err := params.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + slog.InfoContext(ctx, "Trying to join with the new join service") + result, err := joinNew(ctx, params) + if trace.IsNotImplemented(err) { + // Fall back to joining via legacy service. + slog.InfoContext(ctx, "Falling back to joining via the legacy join service", "error", err) + result, err := authjoin.Register(ctx, params) + return result, trace.Wrap(err) + } + return result, trace.Wrap(err) +} + +func joinNew(ctx context.Context, params JoinParams) (*JoinResult, error) { + if params.AuthClient != nil { + return joinViaAuthClient(ctx, params, params.AuthClient) + } + if !params.ProxyServer.IsEmpty() { + return joinViaProxy(ctx, params, params.ProxyServer.String()) + } + // params.AuthServers could contain auth or proxy addresses, try both. + // params.CheckAndSetDefaults() asserts that this list is not empty when + // AuthClient and ProxyServer are both unset. + if authjoin.LooksLikeProxy(params.AuthServers) { + proxyAddr := params.AuthServers[0].String() + slog.InfoContext(ctx, "Attempting to join cluster, address looks like a Proxy", "addr", proxyAddr) + result, proxyJoinErr := joinViaProxy(ctx, params, proxyAddr) + if proxyJoinErr == nil { + return result, nil + } + slog.InfoContext(ctx, "Joining via proxy failed, will try to join via Auth", "error", proxyJoinErr) + result, authJoinErr := joinViaAuth(ctx, params) + return result, trace.Wrap(authJoinErr) + } + addr := params.AuthServers[0].String() + slog.InfoContext(ctx, "Attempting to join cluster, address looks like an Auth server", "addr", addr) + result, authJoinErr := joinViaAuth(ctx, params) + if authJoinErr == nil { + return result, nil + } + slog.InfoContext(ctx, "Joining via auth failed, will try to join via Proxy", "error", authJoinErr) + result, proxyJoinErr := joinViaProxy(ctx, params, addr) + return result, trace.Wrap(proxyJoinErr) +} + +func joinViaProxy(ctx context.Context, params JoinParams, proxyAddr string) (*JoinResult, error) { + // Connect to the proxy's insecure gRPC listener (this is regular TLS, the + // client is not authenticated because it doesn't have certs yet). + conn, err := proxyinsecureclient.NewConnection(ctx, + proxyinsecureclient.ConnectionConfig{ + ProxyServer: proxyAddr, + CipherSuites: params.CipherSuites, + Clock: params.Clock, + Insecure: params.Insecure, + Log: slog.Default(), + }, + ) + if err != nil { + return nil, trace.Wrap(err) + } + defer conn.Close() + return joinWithClient(ctx, params, joinv1.NewClientFromConn(conn)) +} + +func joinViaAuth(ctx context.Context, params JoinParams) (*JoinResult, error) { + authClient, err := authjoin.NewAuthClient(ctx, params) + if err != nil { + return nil, trace.Wrap(err, "building auth client") + } + defer authClient.Close() + return joinViaAuthClient(ctx, params, authClient) +} + +func joinViaAuthClient(ctx context.Context, params JoinParams, authClient authjoin.AuthJoinClient) (*JoinResult, error) { + return joinWithClient(ctx, params, joinv1.NewClient(authClient.JoinV1Client())) +} + +func joinWithClient(ctx context.Context, params JoinParams, client *joinv1.Client) (*JoinResult, error) { + // Clients may specify the join method or not, to let the server choose the + // method based on the provsion token. + var joinMethodPtr *string + switch params.JoinMethod { + case types.JoinMethodUnspecified: + // leave joinMethodPtr nil to let the server pick based on the token + case types.JoinMethodToken: + joinMethod := string(params.JoinMethod) + joinMethodPtr = &joinMethod + default: + return nil, trace.NotImplemented("new join service is not implemented for method %v", params.JoinMethod) + } + + // Initiate the join request, using a cancelable context to make sure the + // stream is closed when this function returns. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + stream, err := client.Join(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + defer stream.CloseSend() + + // Send the ClientInit message with the intended join method, token name, + // and system role. + if err := stream.Send(&messages.ClientInit{ + JoinMethod: joinMethodPtr, + TokenName: params.Token, + SystemRole: params.ID.Role.String(), + }); err != nil { + return nil, trace.Wrap(err) + } + + // Receive the ServerInit message. + serverInit, err := messages.RecvResponse[*messages.ServerInit](stream) + if err != nil { + return nil, trace.Wrap(err) + } + + // Generate keys based on the signature algorithm suite from the ServerInit message. + signer, publicKeys, err := generateKeys(ctx, serverInit.SignatureAlgorithmSuite) + if err != nil { + return nil, trace.Wrap(err) + } + // Build the ClientParams message that will be sent for all join methods. + clientParams := makeClientParams(params, publicKeys) + + // Delegate out to the handler for the specific join method. + if err := joinWithMethod(stream, clientParams, serverInit.JoinMethod); err != nil { + return nil, trace.Wrap(err) + } + + // Receive the final result message. + if params.ID.Role == types.RoleBot { + botResult, err := messages.RecvResponse[*messages.BotResult](stream) + if err != nil { + return nil, trace.Wrap(err) + } + return makeJoinResult(signer, botResult.Certificates) + } + hostResult, err := messages.RecvResponse[*messages.HostResult](stream) + if err != nil { + return nil, trace.Wrap(err) + } + return makeJoinResult(signer, hostResult.Certificates) +} + +func joinWithMethod( + stream messages.ClientStream, + clientParams messages.ClientParams, + method string, +) error { + switch types.JoinMethod(method) { + case types.JoinMethodToken: + return trace.Wrap(tokenJoin(stream, clientParams)) + default: + // TODO(nklaassen): implement remaining join methods. + return trace.NotImplemented("server selected join method %v which is not supported by this client", method) + } +} + +func tokenJoin( + stream messages.ClientStream, + clientParams messages.ClientParams, +) error { + // The token join method is relatively simple, the flow is + // + // client->server ClientInit + // client<-server ServerInit + // client->server Tokeninit + // client<-server Result + // + // At this point the ServerInit messages has already been received, all + // that's left is to send the TokenInit message, the caller will handle + // receiving the final result. + tokenInitMsg := &messages.TokenInit{ + ClientParams: clientParams, + } + return trace.Wrap(stream.Send(tokenInitMsg)) +} + +func makeClientParams(params JoinParams, publicKeys *messages.PublicKeys) messages.ClientParams { + if params.ID.Role == types.RoleBot { + return messages.ClientParams{ + BotParams: &messages.BotParams{ + PublicKeys: *publicKeys, + Expires: params.Expires, + }, + } + } + return messages.ClientParams{ + HostParams: &messages.HostParams{ + PublicKeys: *publicKeys, + HostName: params.ID.NodeName, + AdditionalPrincipals: params.AdditionalPrincipals, + DNSNames: params.DNSNames, + }, + } +} + +func makeJoinResult(signer crypto.Signer, certs messages.Certificates) (*JoinResult, error) { + // Callers expect proto.Certs with PEM-formatted TLS certs and + // authorized_keys formated SSH certs/keys. + sshCert, err := toAuthorizedKey(certs.SSHCert) + if err != nil { + return nil, trace.Wrap(err) + } + sshCAKeys, err := toAuthorizedKeys(certs.SSHCAKeys) + if err != nil { + return nil, trace.Wrap(err) + } + return &JoinResult{ + Certs: &proto.Certs{ + TLS: pemEncodeTLSCert(certs.TLSCert), + TLSCACerts: pemEncodeTLSCerts(certs.TLSCACerts), + SSH: sshCert, + SSHCACerts: sshCAKeys, // SSHCACerts is a misnomer, SSH CAs are just public keys. + }, + PrivateKey: signer, + }, nil +} + +func toAuthorizedKeys(wireFormats [][]byte) ([][]byte, error) { + out := make([][]byte, len(wireFormats)) + for i, wireFormat := range wireFormats { + var err error + out[i], err = toAuthorizedKey(wireFormat) + if err != nil { + return nil, trace.Wrap(err) + } + } + return out, nil +} + +func toAuthorizedKey(wireFormat []byte) ([]byte, error) { + sshPub, err := ssh.ParsePublicKey(wireFormat) + if err != nil { + return nil, trace.Wrap(err) + } + return ssh.MarshalAuthorizedKey(sshPub), nil +} + +func pemEncodeTLSCerts(rawCerts [][]byte) [][]byte { + out := make([][]byte, len(rawCerts)) + for i, rawCert := range rawCerts { + out[i] = pemEncodeTLSCert(rawCert) + } + return out +} + +func pemEncodeTLSCert(rawCert []byte) []byte { + return pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: rawCert, + }) +} + +func generateKeys(ctx context.Context, suite types.SignatureAlgorithmSuite) (crypto.Signer, *messages.PublicKeys, error) { + signer, err := cryptosuites.GenerateKey( + ctx, + cryptosuites.StaticAlgorithmSuite(suite), + cryptosuites.HostIdentity, + ) + if err != nil { + return nil, nil, trace.Wrap(err) + } + tlsPub, err := x509.MarshalPKIXPublicKey(signer.Public()) + if err != nil { + return nil, nil, trace.Wrap(err) + } + sshPub, err := ssh.NewPublicKey(signer.Public()) + if err != nil { + return nil, nil, trace.Wrap(err) + } + return signer, &messages.PublicKeys{ + PublicTLSKey: tlsPub, + PublicSSHKey: sshPub.Marshal(), + }, nil +} diff --git a/lib/join/joinv1/client.go b/lib/join/joinv1/client.go index f7d1d9fae1ae4..994d6a3f54e97 100644 --- a/lib/join/joinv1/client.go +++ b/lib/join/joinv1/client.go @@ -41,6 +41,13 @@ func NewClient(grpcClient joinv1.JoinServiceClient) *Client { } } +// NewClientFromConn returns a new [Client] wrapping plain gRPC ClientConn. +func NewClientFromConn(cc *grpc.ClientConn) *Client { + return &Client{ + grpcClient: joinv1.NewJoinServiceClient(cc), + } +} + // Join implements cluster joining for nodes and bots. func (c *Client) Join(ctx context.Context) (messages.ClientStream, error) { ctx, cancel := context.WithCancelCause(ctx)