From 5d31e4e67836fde71fd9eb7b3231a151710a160a Mon Sep 17 00:00:00 2001 From: astraw99 Date: Fri, 6 Oct 2023 16:52:15 +0800 Subject: [PATCH] Add signal catch to stop the server gracefully --- cmd/hostpathplugin/main.go | 15 ++++++++++++--- internal/endpoint/endpoint.go | 1 - pkg/hostpath/hostpath.go | 6 ++++-- pkg/hostpath/server.go | 9 --------- 4 files changed, 16 insertions(+), 15 deletions(-) diff --git a/cmd/hostpathplugin/main.go b/cmd/hostpathplugin/main.go index 9dcbe7238..b25324fb4 100644 --- a/cmd/hostpathplugin/main.go +++ b/cmd/hostpathplugin/main.go @@ -42,7 +42,7 @@ func main() { VendorVersion: version, } - flag.StringVar(&cfg.Endpoint, "endpoint", "unix://tmp/csi.sock", "CSI endpoint") + flag.StringVar(&cfg.Endpoint, "endpoint", "unix:///tmp/csi.sock", "CSI endpoint") flag.StringVar(&cfg.DriverName, "drivername", "hostpath.csi.k8s.io", "name of the driver") flag.StringVar(&cfg.StateDir, "statedir", "/csi-data-dir", "directory for storing state information across driver restarts, volumes and snapshots") flag.StringVar(&cfg.NodeID, "nodeid", "", "node id") @@ -124,9 +124,18 @@ func main() { os.Exit(1) } - if err := driver.Run(); err != nil { + // Wait for signal + stopCh := make(chan os.Signal, 1) + sigs := []os.Signal{ + syscall.SIGTERM, + syscall.SIGHUP, + syscall.SIGINT, + syscall.SIGQUIT, + } + signal.Notify(stopCh, sigs...) + + if err := driver.Run(stopCh); err != nil { fmt.Printf("Failed to run driver: %s", err.Error()) os.Exit(1) - } } diff --git a/internal/endpoint/endpoint.go b/internal/endpoint/endpoint.go index 4f85b5ba3..c856cb531 100644 --- a/internal/endpoint/endpoint.go +++ b/internal/endpoint/endpoint.go @@ -43,7 +43,6 @@ func Listen(endpoint string) (net.Listener, func(), error) { cleanup := func() {} if proto == "unix" { - addr = "/" + addr if err := os.Remove(addr); err != nil && !os.IsNotExist(err) { //nolint: vetshadow return nil, nil, fmt.Errorf("%s: %q", addr, err) } diff --git a/pkg/hostpath/hostpath.go b/pkg/hostpath/hostpath.go index 4abeac774..16779a1de 100644 --- a/pkg/hostpath/hostpath.go +++ b/pkg/hostpath/hostpath.go @@ -130,14 +130,16 @@ func NewHostPathDriver(cfg Config) (*hostPath, error) { return hp, nil } -func (hp *hostPath) Run() error { +func (hp *hostPath) Run(stopCh <-chan os.Signal) error { s := NewNonBlockingGRPCServer() var sms csi.SnapshotMetadataServer if hp.config.EnableSnapshotMetadata { sms = hp } s.Start(hp.config.Endpoint, hp, hp, hp, hp, sms) - s.Wait() + + <-stopCh + s.Stop() return nil } diff --git a/pkg/hostpath/server.go b/pkg/hostpath/server.go index 7c6971d44..f2ee39cf4 100644 --- a/pkg/hostpath/server.go +++ b/pkg/hostpath/server.go @@ -18,7 +18,6 @@ package hostpath import ( "encoding/json" - "sync" "golang.org/x/net/context" "google.golang.org/grpc" @@ -35,24 +34,17 @@ func NewNonBlockingGRPCServer() *nonBlockingGRPCServer { // NonBlocking server type nonBlockingGRPCServer struct { - wg sync.WaitGroup server *grpc.Server cleanup func() } func (s *nonBlockingGRPCServer) Start(endpoint string, ids csi.IdentityServer, cs csi.ControllerServer, ns csi.NodeServer, gcs csi.GroupControllerServer, sms csi.SnapshotMetadataServer) { - s.wg.Add(1) - go s.serve(endpoint, ids, cs, ns, gcs, sms) return } -func (s *nonBlockingGRPCServer) Wait() { - s.wg.Wait() -} - func (s *nonBlockingGRPCServer) Stop() { s.server.GracefulStop() s.cleanup() @@ -95,7 +87,6 @@ func (s *nonBlockingGRPCServer) serve(ep string, ids csi.IdentityServer, cs csi. klog.Infof("Listening for connections on address: %#v", listener.Addr()) server.Serve(listener) - } func logGRPC(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {