diff --git a/cmd/hostpathplugin/main.go b/cmd/hostpathplugin/main.go
index 9dcbe7238..41fb3e8d1 100644
--- a/cmd/hostpathplugin/main.go
+++ b/cmd/hostpathplugin/main.go
@@ -46,6 +46,7 @@ func main() {
 	flag.StringVar(&cfg.DriverName, "drivername", "hostpath.csi.k8s.io", "name of the driver")
 	flag.StringVar(&cfg.StateDir, "statedir", "/csi-data-dir", "directory for storing state information across driver restarts, volumes and snapshots")
 	flag.StringVar(&cfg.NodeID, "nodeid", "", "node id")
+	flag.StringVar(&cfg.SecondaryNodeID, "worker-node-id", "", "worker node id")
 	flag.BoolVar(&cfg.Ephemeral, "ephemeral", false, "publish volumes in ephemeral mode even if kubelet did not ask for it (only needed for Kubernetes 1.15)")
 	flag.Int64Var(&cfg.MaxVolumesPerNode, "maxvolumespernode", 0, "limit of volumes per node")
 	flag.Var(&cfg.Capacity, "capacity", "Simulate storage capacity. The parameter is <kind>=<quantity> where <kind> is the value of a 'kind' storage class parameter and <quantity> is the total amount of bytes for that kind. The flag may be used multiple times to configure different kinds.")
diff --git a/pkg/hostpath/controllerserver.go b/pkg/hostpath/controllerserver.go
index 5dd570419..a56c430af 100644
--- a/pkg/hostpath/controllerserver.go
+++ b/pkg/hostpath/controllerserver.go
@@ -117,6 +117,9 @@ func (hp *hostPath) CreateVolume(ctx context.Context, req *csi.CreateVolumeReque
 	topologies := []*csi.Topology{}
 	if hp.config.EnableTopology {
 		topologies = append(topologies, &csi.Topology{Segments: map[string]string{TopologyKeyNode: hp.config.NodeID}})
+		if hp.config.SecondaryNodeID != "" {
+			topologies = append(topologies, &csi.Topology{Segments: map[string]string{TopologyKeyNode: hp.config.SecondaryNodeID}})
+		}
 	}
 
 	// Need to check for already existing volume name, and if found
@@ -301,7 +304,7 @@ func (hp *hostPath) ControllerPublishVolume(ctx context.Context, req *csi.Contro
 		return nil, status.Error(codes.InvalidArgument, "Volume Capabilities cannot be empty")
 	}
 
-	if req.NodeId != hp.config.NodeID {
+	if req.NodeId != hp.config.NodeID && req.NodeId != hp.config.SecondaryNodeID {
 		return nil, status.Errorf(codes.NotFound, "Not matching Node ID %s to hostpath Node ID %s", req.NodeId, hp.config.NodeID)
 	}
 
@@ -358,6 +361,11 @@ func (hp *hostPath) ControllerUnpublishVolume(ctx context.Context, req *csi.Cont
 	hp.mutex.Lock()
 	defer hp.mutex.Unlock()
 
+	err := hp.state.SafeReloadData()
+	if err != nil {
+		return nil, err
+	}
+
 	vol, err := hp.state.GetVolumeByID(req.VolumeId)
 	if err != nil {
 		// Not an error: a non-existent volume is not published.
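With these controller-side changes, a driver started with both `-nodeid` and the new `-worker-node-id` flag advertises two accessible-topology segments for every volume it creates, and `ControllerPublishVolume` accepts either node ID. Below is a minimal standalone sketch of that topology construction; the node names are illustrative and the local `topologyKeyNode` constant stands in for the driver's `TopologyKeyNode`:

```go
package main

import (
	"fmt"

	"github.com/container-storage-interface/spec/lib/go/csi"
)

func main() {
	// Illustrative values: in the driver, nodeID and secondaryNodeID come
	// from the -nodeid and -worker-node-id flags, and the key is the
	// driver's TopologyKeyNode constant.
	const topologyKeyNode = "topology.hostpath.csi/node"
	nodeID := "control-plane"
	secondaryNodeID := "worker-1"

	topologies := []*csi.Topology{
		{Segments: map[string]string{topologyKeyNode: nodeID}},
	}
	if secondaryNodeID != "" {
		topologies = append(topologies, &csi.Topology{
			Segments: map[string]string{topologyKeyNode: secondaryNodeID},
		})
	}

	// Both segments end up in the CreateVolumeResponse, so the volume is
	// reported as accessible from either node.
	for _, t := range topologies {
		fmt.Println(t.Segments)
	}
}
```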
diff --git a/pkg/hostpath/hostpath.go b/pkg/hostpath/hostpath.go
index 4abeac774..6095d6647 100644
--- a/pkg/hostpath/hostpath.go
+++ b/pkg/hostpath/hostpath.go
@@ -65,10 +65,12 @@ type hostPath struct {
 }
 
 type Config struct {
-	DriverName        string
-	Endpoint          string
-	ProxyEndpoint     string
-	NodeID            string
+	DriverName    string
+	Endpoint      string
+	ProxyEndpoint string
+	NodeID        string
+	// SecondaryNodeID can be used to deploy hostpath with more than one topology
+	SecondaryNodeID string
 	VendorVersion     string
 	StateDir          string
 	MaxVolumesPerNode int64
diff --git a/pkg/hostpath/nodeserver.go b/pkg/hostpath/nodeserver.go
index def000036..05ddf9c3b 100644
--- a/pkg/hostpath/nodeserver.go
+++ b/pkg/hostpath/nodeserver.go
@@ -64,6 +64,11 @@ func (hp *hostPath) NodePublishVolume(ctx context.Context, req *csi.NodePublishV
 	hp.mutex.Lock()
 	defer hp.mutex.Unlock()
 
+	err := hp.state.SafeReloadData()
+	if err != nil {
+		return nil, err
+	}
+
 	mounter := mount.New("")
 
 	// if ephemeral is specified, create volume here to avoid errors
@@ -223,6 +228,11 @@ func (hp *hostPath) NodeUnpublishVolume(ctx context.Context, req *csi.NodeUnpubl
 	hp.mutex.Lock()
 	defer hp.mutex.Unlock()
 
+	err := hp.state.SafeReloadData()
+	if err != nil {
+		return nil, err
+	}
+
 	vol, err := hp.state.GetVolumeByID(volumeID)
 	if err != nil {
 		return nil, err
@@ -286,6 +296,11 @@ func (hp *hostPath) NodeStageVolume(ctx context.Context, req *csi.NodeStageVolum
 	hp.mutex.Lock()
 	defer hp.mutex.Unlock()
 
+	err := hp.state.SafeReloadData()
+	if err != nil {
+		return nil, err
+	}
+
 	vol, err := hp.state.GetVolumeByID(req.VolumeId)
 	if err != nil {
 		return nil, err
@@ -329,6 +344,11 @@ func (hp *hostPath) NodeUnstageVolume(ctx context.Context, req *csi.NodeUnstageV
 	hp.mutex.Lock()
 	defer hp.mutex.Unlock()
 
+	err := hp.state.SafeReloadData()
+	if err != nil {
+		return nil, err
+	}
+
 	vol, err := hp.state.GetVolumeByID(req.VolumeId)
 	if err != nil {
 		return nil, err
diff --git a/pkg/state/state.go b/pkg/state/state.go
index cd207a371..e7465322f 100644
--- a/pkg/state/state.go
+++ b/pkg/state/state.go
@@ -85,6 +85,11 @@ type GroupSnapshot struct {
 // access and change state. All error messages contain gRPC
 // status codes and can be returned without wrapping.
 type State interface {
+	// SafeReloadData reloads the volume information safely
+	// from underlying state file. If there are any errors
+	// reading old data, then existing state is restored
+	SafeReloadData() error
+
 	// GetVolumeByID retrieves a volume by its unique ID or returns
 	// an error including that ID when not found.
 	GetVolumeByID(volID string) (Volume, error)
@@ -203,6 +208,29 @@ func (s *state) restore() error {
 	return nil
 }
 
+func (s *state) SafeReloadData() error {
+	data, err := os.ReadFile(s.statefilePath)
+	switch {
+	case errors.Is(err, os.ErrNotExist):
+		// Nothing to do.
+		return nil
+	case err != nil:
+		return status.Errorf(codes.Internal, "error reading state file: %v", err)
+	}
+
+	oldVolumes := s.Volumes
+	oldSnapshots := s.Snapshots
+
+	s.Volumes = nil
+	s.Snapshots = nil
+	if err := json.Unmarshal(data, &s.resources); err != nil {
+		s.Volumes = oldVolumes
+		s.Snapshots = oldSnapshots
+		return status.Errorf(codes.Internal, "error encoding volumes and snapshots from state file %q: %v", s.statefilePath, err)
+	}
+	return nil
+}
+
 func (s *state) GetVolumeByID(volID string) (Volume, error) {
 	for _, volume := range s.Volumes {
 		if volume.VolID == volID {
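The `SafeReloadData` additions follow a reload-with-rollback pattern: keep the current in-memory slices, try to re-decode the state file, and put the old data back if decoding fails, so a corrupt or half-written file never wipes working state. A simplified, dependency-free sketch of the same pattern (the struct shape and field names below are stand-ins, not the driver's real `state` type):

```go
package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"os"
)

// store is a simplified stand-in for the driver's state type; the real
// implementation keeps Volume and Snapshot structs, not strings.
type store struct {
	path      string
	Volumes   []string `json:"volumes"`
	Snapshots []string `json:"snapshots"`
}

// reload re-reads the backing file; on a decode error the previously
// loaded data is put back, so callers never observe half-loaded state.
func (s *store) reload() error {
	data, err := os.ReadFile(s.path)
	switch {
	case errors.Is(err, os.ErrNotExist):
		return nil // no file yet: keep whatever is already in memory
	case err != nil:
		return fmt.Errorf("reading state file: %w", err)
	}

	oldVolumes, oldSnapshots := s.Volumes, s.Snapshots
	s.Volumes, s.Snapshots = nil, nil
	if err := json.Unmarshal(data, s); err != nil {
		s.Volumes, s.Snapshots = oldVolumes, oldSnapshots
		return fmt.Errorf("decoding state file %q: %w", s.path, err)
	}
	return nil
}

func main() {
	s := &store{path: "/tmp/hostpath-state.json"}
	if err := s.reload(); err != nil {
		fmt.Println("reload failed, kept previous state:", err)
		return
	}
	fmt.Println("volumes:", s.Volumes, "snapshots:", s.Snapshots)
}
```

Calling `SafeReloadData` at the start of the node and controller RPCs, while `hp.mutex` is held, presumably lets one plugin process pick up volumes that another process persisted to the same `-statedir` before the subsequent `GetVolumeByID` lookup runs.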