Skip to content
100 changes: 13 additions & 87 deletions cmd/csi-rclone-plugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"github.com/SwissDataScienceCenter/csi-rclone/pkg/rclone"
"github.com/spf13/cobra"
"k8s.io/klog"
mountUtils "k8s.io/mount-utils"
)

var (
Expand All @@ -24,8 +23,15 @@ var (
meters []metrics.Observable
)

func exitOnError(err error) {
if err != nil {
klog.Error(err.Error())
os.Exit(1)
}
}

func init() {
flag.Set("logtostderr", "true")
exitOnError(flag.Set("logtostderr", "true"))
}

func main() {
Expand All @@ -48,34 +54,10 @@ func main() {
Use: "run",
Short: "Start the CSI driver.",
}
root.AddCommand(runCmd)
exitOnError(rclone.NodeCommandLineParameters(runCmd, &meters, &nodeID, &endpoint, &cacheDir, &cacheSize))
exitOnError(rclone.ControllerCommandLineParameters(runCmd, &meters, &nodeID, &endpoint))

runNode := &cobra.Command{
Use: "node",
Short: "Start the CSI driver node service - expected to run in a daemonset on every node.",
Run: func(cmd *cobra.Command, args []string) {
handleNode()
},
}
runNode.PersistentFlags().StringVar(&nodeID, "nodeid", "", "node id")
runNode.MarkPersistentFlagRequired("nodeid")
runNode.PersistentFlags().StringVar(&endpoint, "endpoint", "", "CSI endpoint")
runNode.MarkPersistentFlagRequired("endpoint")
runNode.PersistentFlags().StringVar(&cacheDir, "cachedir", "", "cache dir")
runNode.PersistentFlags().StringVar(&cacheSize, "cachesize", "", "cache size")
runCmd.AddCommand(runNode)
runController := &cobra.Command{
Use: "controller",
Short: "Start the CSI driver controller.",
Run: func(cmd *cobra.Command, args []string) {
handleController()
},
}
runController.PersistentFlags().StringVar(&nodeID, "nodeid", "", "node id")
runController.MarkPersistentFlagRequired("nodeid")
runController.PersistentFlags().StringVar(&endpoint, "endpoint", "", "CSI endpoint")
runController.MarkPersistentFlagRequired("endpoint")
runCmd.AddCommand(runController)
root.AddCommand(runCmd)

versionCmd := &cobra.Command{
Use: "version",
Expand All @@ -86,7 +68,7 @@ func main() {
}
root.AddCommand(versionCmd)

root.ParseFlags(os.Args[1:])
exitOnError(root.ParseFlags(os.Args[1:]))

if metricsServerConfig.Enabled {
// Gracefully exit the metrics background servers
Expand All @@ -97,63 +79,7 @@ func main() {
go metricsServer.ListenAndServe()
}

if err := root.Execute(); err != nil {
fmt.Fprintf(os.Stderr, "%s", err.Error())
os.Exit(1)
}
exitOnError(root.Execute())

os.Exit(0)
}

func handleNode() {
err := unmountOldVols()
if err != nil {
klog.Warningf("There was an error when trying to unmount old volumes: %v", err)
}
d := rclone.NewDriver(nodeID, endpoint)
ns, err := rclone.NewNodeServer(d.CSIDriver, cacheDir, cacheSize)
if err != nil {
panic(err)
}
meters = append(meters, ns.Metrics()...)
d.WithNodeServer(ns)
err = d.Run()
if err != nil {
panic(err)
}
}

func handleController() {
d := rclone.NewDriver(nodeID, endpoint)
cs := rclone.NewControllerServer(d.CSIDriver)
meters = append(meters, cs.Metrics()...)
d.WithControllerServer(cs)
err := d.Run()
if err != nil {
panic(err)
}
}

// unmountOldVols is used to unmount volumes after a restart on a node
func unmountOldVols() error {
const mountType = "fuse.rclone"
const unmountTimeout = time.Second * 5
klog.Info("Checking for existing mounts")
mounter := mountUtils.Mounter{}
mounts, err := mounter.List()
if err != nil {
return err
}
for _, mount := range mounts {
if mount.Type != mountType {
continue
}
err := mounter.UnmountWithForce(mount.Path, unmountTimeout)
if err != nil {
klog.Warningf("Failed to unmount %s because of %v.", mount.Path, err)
continue
}
klog.Infof("Sucessfully unmounted %s", mount.Path)
}
return nil
}
101 changes: 78 additions & 23 deletions pkg/rclone/controllerserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
package rclone

import (
"context"
"sync"

"github.com/SwissDataScienceCenter/csi-rclone/pkg/metrics"
"github.com/container-storage-interface/spec/lib/go/csi"
"golang.org/x/net/context"
"github.com/prometheus/client_golang/prometheus"
"github.com/spf13/cobra"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"k8s.io/klog"
Expand All @@ -16,13 +19,66 @@ import (

const secretAnnotationName = "csi-rclone.dev/secretName"

type controllerServer struct {
type ControllerServer struct {
*csicommon.DefaultControllerServer
active_volumes map[string]int64
mutex sync.RWMutex
activeVolumes map[string]int64
mutex sync.RWMutex
}

func (cs *controllerServer) ValidateVolumeCapabilities(ctx context.Context, req *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) {
func NewControllerServer(csiDriver *csicommon.CSIDriver) *ControllerServer {
return &ControllerServer{
DefaultControllerServer: csicommon.NewDefaultControllerServer(csiDriver),
activeVolumes: map[string]int64{},
mutex: sync.RWMutex{},
}
}

func (cs *ControllerServer) metrics() []metrics.Observable {
var meters []metrics.Observable

meter := prometheus.NewGauge(prometheus.GaugeOpts{
Name: "csi_rclone_active_volume_count",
Help: "Number of active (Mounted) volumes.",
})
meters = append(meters,
func() { meter.Set(float64(len(cs.activeVolumes))) },
)
prometheus.MustRegister(meter)

return meters
}

func ControllerCommandLineParameters(runCmd *cobra.Command, meters *[]metrics.Observable, nodeID, endpoint *string) error {
runController := &cobra.Command{
Use: "controller",
Short: "Start the CSI driver controller.",
RunE: func(cmd *cobra.Command, args []string) error {
return Run(context.Background(),
nodeID,
endpoint,
func(csiDriver *csicommon.CSIDriver) (csi.ControllerServer, csi.NodeServer, error) {
cs := NewControllerServer(csiDriver)
*meters = append(*meters, cs.metrics()...)
return cs, nil, nil
},
func(_ context.Context) error { return nil },
)
},
}
runController.PersistentFlags().StringVar(nodeID, "nodeid", "", "node id")
if err := runController.MarkPersistentFlagRequired("nodeid"); err != nil {
return err
}
runController.PersistentFlags().StringVar(endpoint, "endpoint", "", "CSI endpoint")
if err := runController.MarkPersistentFlagRequired("endpoint"); err != nil {
return err
}

runCmd.AddCommand(runController)
return nil
}

func (cs *ControllerServer) ValidateVolumeCapabilities(_ context.Context, req *csi.ValidateVolumeCapabilitiesRequest) (*csi.ValidateVolumeCapabilitiesResponse, error) {
volId := req.GetVolumeId()
if len(volId) == 0 {
return nil, status.Error(codes.InvalidArgument, "ValidateVolumeCapabilities must be provided volume id")
Expand All @@ -33,7 +89,7 @@ func (cs *controllerServer) ValidateVolumeCapabilities(ctx context.Context, req

cs.mutex.Lock()
defer cs.mutex.Unlock()
if _, ok := cs.active_volumes[volId]; !ok {
if _, ok := cs.activeVolumes[volId]; !ok {
return nil, status.Errorf(codes.NotFound, "Volume %s not found", volId)
}
return &csi.ValidateVolumeCapabilitiesResponse{
Expand All @@ -45,18 +101,18 @@ func (cs *controllerServer) ValidateVolumeCapabilities(ctx context.Context, req
}, nil
}

// Attaching Volume
func (cs *controllerServer) ControllerPublishVolume(ctx context.Context, req *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) {
// ControllerPublishVolume Attaching Volume
func (cs *ControllerServer) ControllerPublishVolume(_ context.Context, _ *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method ControllerPublishVolume not implemented")
}

// Detaching Volume
func (cs *controllerServer) ControllerUnpublishVolume(ctx context.Context, req *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) {
// ControllerUnpublishVolume Detaching Volume
func (cs *ControllerServer) ControllerUnpublishVolume(_ context.Context, _ *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method ControllerUnpublishVolume not implemented")
}

// Provisioning Volumes
func (cs *controllerServer) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) {
// CreateVolume Provisioning Volumes
func (cs *ControllerServer) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) {
klog.Infof("ControllerCreateVolume: called with args %+v", *req)
volumeName := req.GetName()
if len(volumeName) == 0 {
Expand All @@ -70,18 +126,18 @@ func (cs *controllerServer) CreateVolume(ctx context.Context, req *csi.CreateVol
// we don't use the size as it makes no sense for rclone. but csi drivers should succeed if
// called twice with the same capacity for the same volume and fail if called twice with
// differing capacity, so we need to remember it
volSizeBytes := int64(req.GetCapacityRange().GetRequiredBytes())
volSizeBytes := req.GetCapacityRange().GetRequiredBytes()
cs.mutex.Lock()
defer cs.mutex.Unlock()
if val, ok := cs.active_volumes[volumeName]; ok && val != volSizeBytes {
if val, ok := cs.activeVolumes[volumeName]; ok && val != volSizeBytes {
return nil, status.Errorf(codes.AlreadyExists, "Volume operation already exists for volume %s", volumeName)
}
cs.active_volumes[volumeName] = volSizeBytes
cs.activeVolumes[volumeName] = volSizeBytes

// See https://github.com/kubernetes-csi/external-provisioner/blob/v5.1.0/pkg/controller/controller.go#L75
// on how parameters from the persistent volume are parsed
// We have to pass the secret name and namespace into the context so that the node server can use them
// The external provisioner uses the secret name and namespace but it does not pass them into the request,
// The external provisioner uses the secret name and namespace, but it does not pass them into the request,
// so we read the PVC here to extract them ourselves because we may need them in the node server for decoding secrets.
pvcName, pvcNameFound := req.Parameters["csi.storage.k8s.io/pvc/name"]
pvcNamespace, pvcNamespaceFound := req.Parameters["csi.storage.k8s.io/pvc/namespace"]
Expand Down Expand Up @@ -114,29 +170,28 @@ func (cs *controllerServer) CreateVolume(ctx context.Context, req *csi.CreateVol

}

// Delete Volume
func (cs *controllerServer) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) {
func (cs *ControllerServer) DeleteVolume(_ context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) {
volId := req.GetVolumeId()
if len(volId) == 0 {
return nil, status.Error(codes.InvalidArgument, "DeteleVolume must be provided volume id")
return nil, status.Error(codes.InvalidArgument, "DeleteVolume must be provided volume id")
}
cs.mutex.Lock()
defer cs.mutex.Unlock()
delete(cs.active_volumes, volId)
delete(cs.activeVolumes, volId)

return &csi.DeleteVolumeResponse{}, nil
}

func (*controllerServer) ControllerExpandVolume(ctx context.Context, req *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) {
func (*ControllerServer) ControllerExpandVolume(_ context.Context, _ *csi.ControllerExpandVolumeRequest) (*csi.ControllerExpandVolumeResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method ControllerExpandVolume not implemented")
}

func (cs *controllerServer) ControllerGetVolume(ctx context.Context, req *csi.ControllerGetVolumeRequest) (*csi.ControllerGetVolumeResponse, error) {
func (cs *ControllerServer) ControllerGetVolume(_ context.Context, req *csi.ControllerGetVolumeRequest) (*csi.ControllerGetVolumeResponse, error) {
return &csi.ControllerGetVolumeResponse{Volume: &csi.Volume{
VolumeId: req.VolumeId,
}}, nil
}

func (cs *controllerServer) ControllerModifyVolume(ctx context.Context, req *csi.ControllerModifyVolumeRequest) (*csi.ControllerModifyVolumeResponse, error) {
func (cs *ControllerServer) ControllerModifyVolume(_ context.Context, _ *csi.ControllerModifyVolumeRequest) (*csi.ControllerModifyVolumeResponse, error) {
return &csi.ControllerModifyVolumeResponse{}, nil
}
Loading