diff --git a/client/cmd/root.go b/client/cmd/root.go
index 5084bd38a5d..11e5228f1e4 100644
--- a/client/cmd/root.go
+++ b/client/cmd/root.go
@@ -231,7 +231,7 @@ func FlagNameToEnvVar(cmdFlag string, prefix string) string {

 // DialClientGRPCServer returns client connection to the daemon server.
 func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, error) {
-	ctx, cancel := context.WithTimeout(ctx, time.Second*3)
+	ctx, cancel := context.WithTimeout(ctx, time.Second*10)
 	defer cancel()

 	return grpc.DialContext(
diff --git a/client/cmd/up.go b/client/cmd/up.go
index e686625d670..d047c041e5e 100644
--- a/client/cmd/up.go
+++ b/client/cmd/up.go
@@ -230,7 +230,9 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
 	client := proto.NewDaemonServiceClient(conn)

-	status, err := client.Status(ctx, &proto.StatusRequest{})
+	status, err := client.Status(ctx, &proto.StatusRequest{
+		WaitForReady: func() *bool { b := true; return &b }(),
+	})
 	if err != nil {
 		return fmt.Errorf("unable to get daemon status: %v", err)
 	}

diff --git a/client/embed/embed.go b/client/embed/embed.go
index de83f9d96f3..0bfc7a37cbe 100644
--- a/client/embed/embed.go
+++ b/client/embed/embed.go
@@ -135,7 +135,7 @@ func (c *Client) Start(startCtx context.Context) error {

 	// either startup error (permanent backoff err) or nil err (successful engine up)
 	// TODO: make after-startup backoff err available
-	run := make(chan struct{}, 1)
+	run := make(chan struct{})
 	clientErr := make(chan error, 1)
 	go func() {
 		if err := client.Run(run); err != nil {
diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go
index 7ac950d852c..69e3f088c57 100644
--- a/client/grpc/dialer.go
+++ b/client/grpc/dialer.go
@@ -58,7 +58,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
 	return backoff.WithContext(b, ctx)
 }

-func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
+func CreateConnection(ctx context.Context, addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
 	transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
 	if tlsEnabled {
 		certPool, err := x509.SystemCertPool()
@@ -72,7 +72,7 @@ func CreateConnection(addr string, tlsEnabled bool) (*grpc.ClientConn, error) {
 		}))
 	}

-	connCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+	connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
 	defer cancel()

 	conn, err := grpc.DialContext(
diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go
index c633afc83f2..841e3c0f777 100644
--- a/client/proto/daemon.pb.go
+++ b/client/proto/daemon.pb.go
@@ -1,7 +1,7 @@
 // Code generated by protoc-gen-go. DO NOT EDIT.
 // versions:
 // 	protoc-gen-go v1.36.6
-// 	protoc        v5.29.3
+// 	protoc        v6.32.1
 // source: daemon.proto

 package proto
@@ -794,8 +794,10 @@ type StatusRequest struct {
 	state             protoimpl.MessageState `protogen:"open.v1"`
 	GetFullPeerStatus bool                   `protobuf:"varint,1,opt,name=getFullPeerStatus,proto3" json:"getFullPeerStatus,omitempty"`
 	ShouldRunProbes   bool                   `protobuf:"varint,2,opt,name=shouldRunProbes,proto3" json:"shouldRunProbes,omitempty"`
-	unknownFields     protoimpl.UnknownFields
-	sizeCache         protoimpl.SizeCache
+	// the UI does not use this yet, but CLIs could use it to wait until the status is ready
+	WaitForReady  *bool `protobuf:"varint,3,opt,name=waitForReady,proto3,oneof" json:"waitForReady,omitempty"`
+	unknownFields protoimpl.UnknownFields
+	sizeCache     protoimpl.SizeCache
 }

 func (x *StatusRequest) Reset() {
@@ -842,6 +844,13 @@ func (x *StatusRequest) GetShouldRunProbes() bool {
 	return false
 }

+func (x *StatusRequest) GetWaitForReady() bool {
+	if x != nil && x.WaitForReady != nil {
+		return *x.WaitForReady
+	}
+	return false
+}
+
 type StatusResponse struct {
 	state protoimpl.MessageState `protogen:"open.v1"`
 	// status of the server.
@@ -4673,10 +4682,12 @@ const file_daemon_proto_rawDesc = "" +
 	"\f_profileNameB\v\n" +
 	"\t_username\"\f\n" +
 	"\n" +
-	"UpResponse\"g\n" +
+	"UpResponse\"\xa1\x01\n" +
 	"\rStatusRequest\x12,\n" +
 	"\x11getFullPeerStatus\x18\x01 \x01(\bR\x11getFullPeerStatus\x12(\n" +
-	"\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\"\x82\x01\n" +
+	"\x0fshouldRunProbes\x18\x02 \x01(\bR\x0fshouldRunProbes\x12'\n" +
+	"\fwaitForReady\x18\x03 \x01(\bH\x00R\fwaitForReady\x88\x01\x01B\x0f\n" +
+	"\r_waitForReady\"\x82\x01\n" +
 	"\x0eStatusResponse\x12\x16\n" +
 	"\x06status\x18\x01 \x01(\tR\x06status\x122\n" +
 	"\n" +
@@ -5231,6 +5242,7 @@ func file_daemon_proto_init() {
 	}
 	file_daemon_proto_msgTypes[1].OneofWrappers = []any{}
 	file_daemon_proto_msgTypes[5].OneofWrappers = []any{}
+	file_daemon_proto_msgTypes[7].OneofWrappers = []any{}
 	file_daemon_proto_msgTypes[26].OneofWrappers = []any{
 		(*PortInfo_Port)(nil),
 		(*PortInfo_Range_)(nil),
diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto
index 0cd3579b913..5b27b4d9850 100644
--- a/client/proto/daemon.proto
+++ b/client/proto/daemon.proto
@@ -186,6 +186,8 @@ message UpResponse {}
 message StatusRequest{
   bool getFullPeerStatus = 1;
   bool shouldRunProbes = 2;
+  // the UI does not use this yet, but CLIs could use it to wait until the status is ready
+  optional bool waitForReady = 3;
 }

 message StatusResponse{
diff --git a/client/server/server.go b/client/server/server.go
index fae342f78b8..e6de608c529 100644
--- a/client/server/server.go
+++ b/client/server/server.go
@@ -67,6 +67,7 @@ type Server struct {
 	proto.UnimplementedDaemonServiceServer
 	clientRunning     bool // protected by mutex
 	clientRunningChan chan struct{}
+	clientGiveUpChan  chan struct{}

 	connectClient *internal.ConnectClient

@@ -106,6 +107,10 @@ func (s *Server) Start() error {
 	s.mutex.Lock()
 	defer s.mutex.Unlock()

+	if s.clientRunning {
+		return nil
+	}
+
 	state := internal.CtxGetState(s.rootCtx)

 	if err := handlePanicLog(); err != nil {
@@ -175,12 +180,10 @@ func (s *Server) Start() error {
 		return nil
 	}

-	if s.clientRunning {
-		return nil
-	}
 	s.clientRunning = true
-	s.clientRunningChan = make(chan struct{}, 1)
-	go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan)
+	s.clientRunningChan = make(chan struct{})
+	s.clientGiveUpChan = make(chan struct{})
+	go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)

 	return nil
 }

@@ -211,7 +214,7 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error {
 // connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
 // mechanism to keep the client connected even when the connection is lost.
 // we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
-func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) {
+func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
 	defer func() {
 		s.mutex.Lock()
 		s.clientRunning = false
@@ -261,6 +264,10 @@
 	if err := backoff.Retry(runOperation, backOff); err != nil {
 		log.Errorf("operation failed: %v", err)
 	}
+
+	if giveUpChan != nil {
+		close(giveUpChan)
+	}
 }

 // loginAttempt attempts to login using the provided information. it returns a status in case something fails
@@ -379,7 +386,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
 	if s.actCancel != nil {
 		s.actCancel()
 	}
-	ctx, cancel := context.WithCancel(s.rootCtx)
+	ctx, cancel := context.WithCancel(callerCtx)

 	md, ok := metadata.FromIncomingContext(callerCtx)
 	if ok {
@@ -389,11 +396,11 @@
 	s.actCancel = cancel
 	s.mutex.Unlock()

-	if err := restoreResidualState(ctx, s.profileManager.GetStatePath()); err != nil {
+	if err := restoreResidualState(s.rootCtx, s.profileManager.GetStatePath()); err != nil {
 		log.Warnf(errRestoreResidualState, err)
 	}

-	state := internal.CtxGetState(ctx)
+	state := internal.CtxGetState(s.rootCtx)
 	defer func() {
 		status, err := state.Status()
 		if err != nil || (status != internal.StatusNeedsLogin && status != internal.StatusLoginFailed) {
@@ -606,6 +613,20 @@ func (s *Server) WaitSSOLogin(callerCtx context.Context, msg *proto.WaitSSOLogin
 // Up starts engine work in the daemon.
 func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpResponse, error) {
 	s.mutex.Lock()
+	if s.clientRunning {
+		state := internal.CtxGetState(s.rootCtx)
+		status, err := state.Status()
+		if err != nil {
+			s.mutex.Unlock()
+			return nil, err
+		}
+		if status == internal.StatusNeedsLogin {
+			s.actCancel()
+		}
+		s.mutex.Unlock()
+
+		return s.waitForUp(callerCtx)
+	}
 	defer s.mutex.Unlock()

 	if err := restoreResidualState(callerCtx, s.profileManager.GetStatePath()); err != nil {
@@ -621,16 +642,16 @@
 	if err != nil {
 		return nil, err
 	}
+
 	if status != internal.StatusIdle {
 		return nil, fmt.Errorf("up already in progress: current status %s", status)
 	}

-	// it should be nil here, but .
+	// it should be nil here, but in case it isn't, we cancel it.
 	if s.actCancel != nil {
 		s.actCancel()
 	}
 	ctx, cancel := context.WithCancel(s.rootCtx)
-
 	md, ok := metadata.FromIncomingContext(callerCtx)
 	if ok {
 		ctx = metadata.NewOutgoingContext(ctx, md)
@@ -673,26 +694,31 @@
 	s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
 	s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)

+	s.clientRunning = true
+	s.clientRunningChan = make(chan struct{})
+	s.clientGiveUpChan = make(chan struct{})
+	go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
+
+	return s.waitForUp(callerCtx)
+}
+
+// todo: handle potential race conditions
+func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error) {
 	timeoutCtx, cancel := context.WithTimeout(callerCtx, 50*time.Second)
 	defer cancel()

-	if !s.clientRunning {
-		s.clientRunning = true
-		s.clientRunningChan = make(chan struct{}, 1)
-		go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan)
-	}
-
-	for {
-		select {
-		case <-s.clientRunningChan:
-			s.isSessionActive.Store(true)
-			return &proto.UpResponse{}, nil
-		case <-callerCtx.Done():
-			log.Debug("context done, stopping the wait for engine to become ready")
-			return nil, callerCtx.Err()
-		case <-timeoutCtx.Done():
-			log.Debug("up is timed out, stopping the wait for engine to become ready")
-			return nil, timeoutCtx.Err()
-		}
+	select {
+	case <-s.clientGiveUpChan:
+		return nil, fmt.Errorf("client gave up connecting")
+	case <-s.clientRunningChan:
+		s.isSessionActive.Store(true)
+		return &proto.UpResponse{}, nil
+	case <-callerCtx.Done():
+		log.Debug("context done, stopping the wait for engine to become ready")
+		return nil, callerCtx.Err()
+	case <-timeoutCtx.Done():
+		log.Debug("up is timed out, stopping the wait for engine to become ready")
+		return nil, timeoutCtx.Err()
 	}
 }

@@ -966,12 +992,46 @@ func (s *Server) Status(
 	ctx context.Context,
 	msg *proto.StatusRequest,
 ) (*proto.StatusResponse, error) {
-	if ctx.Err() != nil {
-		return nil, ctx.Err()
-	}
-
 	s.mutex.Lock()
-	defer s.mutex.Unlock()
+	clientRunning := s.clientRunning
+	s.mutex.Unlock()
+
+	if msg.WaitForReady != nil && *msg.WaitForReady && clientRunning {
+		state := internal.CtxGetState(s.rootCtx)
+		status, err := state.Status()
+		if err != nil {
+			return nil, err
+		}
+
+		if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
+			s.actCancel()
+		}
+
+		ticker := time.NewTicker(1 * time.Second)
+		defer ticker.Stop()
+	loop:
+		for {
+			select {
+			case <-s.clientGiveUpChan:
+				ticker.Stop()
+				break loop
+			case <-s.clientRunningChan:
+				ticker.Stop()
+				break loop
+			case <-ticker.C:
+				status, err := state.Status()
+				if err != nil {
+					continue
+				}
+				if status != internal.StatusIdle && status != internal.StatusConnected && status != internal.StatusConnecting {
+					s.actCancel()
+				}
+				continue
+			case <-ctx.Done():
+				return nil, ctx.Err()
+			}
+		}
+	}

 	status, err := internal.CtxGetState(s.rootCtx).Status()
 	if err != nil {
diff --git a/client/server/server_test.go b/client/server/server_test.go
index 45a1aa5c7e0..7559250039b 100644
--- a/client/server/server_test.go
+++ b/client/server/server_test.go
@@ -105,7 +105,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
 	t.Setenv(maxRetryTimeVar, "5s")
 	t.Setenv(retryMultiplierVar, "1")

-	s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil)
+	s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
 	if counter < 3 {
 		t.Fatalf("expected counter > 2, got %d", counter)
 	}
@@ -134,8 +134,12 @@ func TestServer_Up(t *testing.T) {

 	profName := "default"

+	u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345")
+	require.NoError(t, err)
+
 	ic := profilemanager.ConfigInput{
-		ConfigPath: filepath.Join(tempDir, profName+".json"),
+		ConfigPath:    filepath.Join(tempDir, profName+".json"),
+		ManagementURL: u.String(),
 	}

 	_, err = profilemanager.UpdateOrCreateConfig(ic)
@@ -153,16 +157,9 @@ func TestServer_Up(t *testing.T) {
 	}

 	s := New(ctx, "console", "", false, false)
-
 	err = s.Start()
 	require.NoError(t, err)

-	u, err := url.Parse("http://non-existent-url-for-testing.invalid:12345")
-	require.NoError(t, err)
-
-	s.config = &profilemanager.Config{
-		ManagementURL: u,
-	}
-
 	upCtx, cancel := context.WithTimeout(ctx, 1*time.Second)
 	defer cancel()

@@ -171,6 +168,7 @@ func TestServer_Up(t *testing.T) {
 		Username: &currUser.Username,
 	}
 	_, err = s.Up(upCtx, upReq)
+	log.Errorf("error from Up: %v", err)
 	assert.Contains(t, err.Error(), "context deadline exceeded")
 }
diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go
index 03cc5aec391..f30e965be85 100644
--- a/shared/management/client/grpc.go
+++ b/shared/management/client/grpc.go
@@ -52,7 +52,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE

 	operation := func() error {
 		var err error
-		conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
+		conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled)
 		if err != nil {
 			log.Printf("createConnection error: %v", err)
 			return err
diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go
index 48d1ff04f83..5ca0c028237 100644
--- a/shared/signal/client/grpc.go
+++ b/shared/signal/client/grpc.go
@@ -57,7 +57,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo

 	operation := func() error {
 		var err error
-		conn, err = nbgrpc.CreateConnection(addr, tlsEnabled)
+		conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled)
 		if err != nil {
 			log.Printf("createConnection error: %v", err)
 			return err
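
Usage sketch (not part of the patch): a minimal example of how a CLI caller could combine `DialClientGRPCServer` with the new optional `waitForReady` field from the hunks above. The import paths, daemon address, and timeout are assumptions for illustration, not taken from this diff.

```go
package main

import (
	"context"
	"fmt"
	"time"

	"github.com/netbirdio/netbird/client/cmd"   // assumed import path for DialClientGRPCServer
	"github.com/netbirdio/netbird/client/proto" // assumed import path for the generated daemon API
)

// waitForDaemonReady blocks until the daemon engine is ready, the daemon
// gives up connecting, or the timeout expires.
func waitForDaemonReady(ctx context.Context, addr string) error {
	ctx, cancel := context.WithTimeout(ctx, 60*time.Second) // illustrative timeout
	defer cancel()

	conn, err := cmd.DialClientGRPCServer(ctx, addr)
	if err != nil {
		return fmt.Errorf("dial daemon: %w", err)
	}
	defer conn.Close()

	// waitForReady is optional in the proto; leaving it unset returns the
	// current status immediately, matching the old behavior.
	waitForReady := true
	status, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{
		WaitForReady: &waitForReady,
	})
	if err != nil {
		return fmt.Errorf("get daemon status: %w", err)
	}

	fmt.Println("daemon status:", status.GetStatus())
	return nil
}

func main() {
	// The daemon address is an assumption for illustration.
	if err := waitForDaemonReady(context.Background(), "unix:///var/run/netbird.sock"); err != nil {
		fmt.Println("error:", err)
	}
}
```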