@@ -26,6 +26,7 @@ import (
2626 "k8s.io/client-go/kubernetes"
2727 "k8s.io/client-go/tools/cache"
2828 "k8s.io/client-go/tools/clientcmd"
29+ "sigs.k8s.io/yaml"
2930
3031 "context"
3132 "sync"
@@ -37,7 +38,13 @@ import (
3738 "k8s.io/apimachinery/pkg/fields"
3839 "k8s.io/apimachinery/pkg/util/wait"
3940
41+ migpartedv1 "github.com/NVIDIA/mig-parted/api/spec/v1"
42+ migtypes "github.com/NVIDIA/mig-parted/pkg/types"
43+
44+ v1 "github.com/NVIDIA/vgpu-device-manager/api/spec/v1"
45+ "github.com/NVIDIA/vgpu-device-manager/cmd/nvidia-vgpu-dm/assert"
4046 "github.com/NVIDIA/vgpu-device-manager/internal/info"
47+ "github.com/NVIDIA/vgpu-device-manager/pkg/types"
4148)
4249
4350const (
@@ -47,6 +54,13 @@ const (
4754 vGPUConfigStateLabel = "nvidia.com/vgpu.config.state"
4855 pluginStateLabel = "nvidia.com/gpu.deploy.sandbox-device-plugin"
4956 validatorStateLabel = "nvidia.com/gpu.deploy.sandbox-validator"
57+
58+ defaultHostRootMount = "/host"
59+ defaultHostNvidiaDir = "/usr/local/nvidia"
60+ defaultHostMigManagerStateFile = "/etc/systemd/system/nvidia-mig-manager.service.d/override.conf"
61+ defaultHostKubeletSystemdService = "kubelet.service"
62+
63+ migConfigDisabled = "all-disabled"
5064)
5165
5266var (
@@ -56,10 +70,23 @@ var (
5670 configFileFlag string
5771 defaultVGPUConfigFlag string
5872
73+ hostRootMountFlag string
74+ hostNvidiaDirFlag string
75+ hostMigManagerStateFileFlag string
76+ hostKubeletSystemdServiceFlag string
77+ gpuClientsFileFlag string
78+ withRebootFlag bool
79+ withShutdownHostGPUClientsFlag bool
80+
5981 pluginDeployed string
6082 validatorDeployed string
6183)
6284
85+ type GPUClients struct {
86+ Version string `json:"version" yaml:"version"`
87+ SystemdServices []string `json:"systemd-services" yaml:"systemd-services"`
88+ }
89+
6390// SyncableVGPUConfig is used to synchronize on changes to a configuration value.
6491// That is, callers of Get() will block until a call to Set() is made.
6592// Multiple calls to Set() do not queue, meaning that only calls to Get() made
@@ -148,6 +175,62 @@ func main() {
148175 Destination : & defaultVGPUConfigFlag ,
149176 EnvVars : []string {"DEFAULT_VGPU_CONFIG" },
150177 },
178+ & cli.StringFlag {
179+ Name : "host-root-mount" ,
180+ Aliases : []string {"m" },
181+ Value : defaultHostRootMount ,
182+ Usage : "container path where host root directory is mounted" ,
183+ Destination : & hostRootMountFlag ,
184+ EnvVars : []string {"HOST_ROOT_MOUNT" },
185+ },
186+ & cli.StringFlag {
187+ Name : "host-nvidia-dir" ,
188+ Aliases : []string {"i" },
189+ Value : defaultHostNvidiaDir ,
190+ Usage : "host path of the directory where NVIDIA managed software directory is typically located" ,
191+ Destination : & hostNvidiaDirFlag ,
192+ EnvVars : []string {"HOST_NVIDIA_DIR" },
193+ },
194+ & cli.StringFlag {
195+ Name : "host-mig-manager-state-file" ,
196+ Aliases : []string {"o" },
197+ Value : defaultHostMigManagerStateFile ,
198+ Usage : "host path where the host's systemd mig-manager state file is located" ,
199+ Destination : & hostMigManagerStateFileFlag ,
200+ EnvVars : []string {"HOST_MIG_MANAGER_STATE_FILE" },
201+ },
202+ & cli.StringFlag {
203+ Name : "host-kubelet-systemd-service" ,
204+ Aliases : []string {"k" },
205+ Value : defaultHostKubeletSystemdService ,
206+ Usage : "name of the host's 'kubelet' systemd service which may need to be shutdown/restarted across a MIG mode reconfiguration" ,
207+ Destination : & hostKubeletSystemdServiceFlag ,
208+ EnvVars : []string {"HOST_KUBELET_SYSTEMD_SERVICE" },
209+ },
210+ & cli.StringFlag {
211+ Name : "gpu-clients-file" ,
212+ Aliases : []string {"g" },
213+ Value : "" ,
214+ Usage : "the path to the file listing the GPU clients that need to be shutdown across a MIG configuration" ,
215+ Destination : & gpuClientsFileFlag ,
216+ EnvVars : []string {"GPU_CLIENTS_FILE" },
217+ },
218+ & cli.BoolFlag {
219+ Name : "with-reboot" ,
220+ Aliases : []string {"r" },
221+ Value : false ,
222+ Usage : "reboot the node if changing the MIG mode fails for any reason" ,
223+ Destination : & withRebootFlag ,
224+ EnvVars : []string {"WITH_REBOOT" },
225+ },
226+ & cli.BoolFlag {
227+ Name : "with-shutdown-host-gpu-clients" ,
228+ Aliases : []string {"w" },
229+ Value : false ,
230+ Usage : "shutdown/restart any required host GPU clients across a MIG configuration" ,
231+ Destination : & withShutdownHostGPUClientsFlag ,
232+ EnvVars : []string {"WITH_SHUTDOWN_HOST_GPU_CLIENTS" },
233+ },
151234 }
152235
153236 log .Infof ("version: %s" , c .Version )
@@ -296,6 +379,10 @@ func updateConfig(clientset *kubernetes.Clientset, selectedConfig string) error
296379 return fmt .Errorf ("unable to shutdown gpu operands: %v" , err )
297380 }
298381
382+ if err := handleMIGConfiguration (clientset , selectedConfig ); err != nil {
383+ return fmt .Errorf ("unable to handle MIG configuration: %v" , err )
384+ }
385+
299386 log .Info ("Applying the selected vGPU device configuration to the node" )
300387 err = applyConfig (selectedConfig )
301388 if err != nil {
@@ -504,3 +591,158 @@ func setNodeLabelValue(clientset *kubernetes.Clientset, label, value string) err
504591
505592 return nil
506593}
594+
595+ func handleMIGConfiguration (clientset kubernetes.Interface , selectedConfig string ) error {
596+ if err := isNVMLAvailable (); err != nil {
597+ log .Infof ("Skipping MIG configuration due to NVML error: %v, proceeding with vGPU configuration" , err )
598+ return nil
599+ }
600+
601+ migConfig , err := determineMIGConfig (selectedConfig )
602+ if err != nil {
603+ return err
604+ }
605+
606+ configFile , err := saveMIGConfigToTempFile (migConfig )
607+ if err != nil {
608+ return fmt .Errorf ("failed to save MIG config to temporary file: %w" , err )
609+ }
610+
611+ return updateMIGConfig (clientset .(* kubernetes.Clientset ), configFile , selectedConfig )
612+ }
613+
614+ func determineMIGConfig (selectedConfig string ) (* migpartedv1.Spec , error ) {
615+ f := & assert.Flags {
616+ ConfigFile : configFileFlag ,
617+ SelectedConfig : selectedConfig ,
618+ ValidConfig : false , // We don't need to validate the config here, just parse it.
619+ }
620+
621+ log .Debugf ("Parsing vGPU config file..." )
622+ spec , err := assert .ParseConfigFile (f )
623+ if err != nil {
624+ return nil , fmt .Errorf ("error parsing config file: %v" , err )
625+ }
626+
627+ log .Debugf ("Selecting specific vGPU config..." )
628+ vgpuConfig , err := assert .GetSelectedVGPUConfig (f , spec )
629+ if err != nil {
630+ return nil , fmt .Errorf ("error selecting VGPU config: %v" , err )
631+ }
632+
633+ return convertToMIGConfig (vgpuConfig , selectedConfig )
634+ }
635+
636+ func convertToMIGConfig (vgpuConfig v1.VGPUConfigSpecSlice , selectedConfig string ) (* migpartedv1.Spec , error ) {
637+ var migConfigSpecs migpartedv1.MigConfigSpecSlice
638+
639+ for _ , vgpuSpec := range vgpuConfig {
640+ migSpec := migpartedv1.MigConfigSpec {
641+ DeviceFilter : vgpuSpec .DeviceFilter ,
642+ Devices : vgpuSpec .Devices ,
643+ MigDevices : make (migtypes.MigConfig ),
644+ }
645+
646+ migEnabled := false
647+ for vgpuType := range vgpuSpec .VGPUDevices {
648+ vgpu , err := types .ParseVGPUType (vgpuType )
649+ if err != nil {
650+ return nil , fmt .Errorf ("failed to parse vGPU type %s: %w" , vgpuType , err )
651+ }
652+
653+ if vgpu .G > 0 {
654+ migEnabled = true
655+ migProfile := fmt .Sprintf ("%dg.%dgb" , vgpu .G , vgpu .GB )
656+ for _ , attr := range vgpu .Attr {
657+ if attr == types .AttributeMediaExtensions {
658+ migProfile += ".me"
659+ break
660+ }
661+ }
662+ migSpec .MigDevices [migProfile ] = vgpuSpec .VGPUDevices [vgpuType ]
663+ }
664+ }
665+
666+ migSpec .MigEnabled = migEnabled
667+
668+ migConfigSpecs = append (migConfigSpecs , migSpec )
669+ }
670+
671+ spec := & migpartedv1.Spec {
672+ Version : migpartedv1 .Version ,
673+ MigConfigs : map [string ]migpartedv1.MigConfigSpecSlice {
674+ selectedConfig : migConfigSpecs ,
675+ },
676+ }
677+ return spec , nil
678+ }
679+
680+ func saveMIGConfigToTempFile (migConfig * migpartedv1.Spec ) (string , error ) {
681+ tempFile , err := os .CreateTemp ("" , "mig-parted-config-*.yaml" )
682+ if err != nil {
683+ return "" , fmt .Errorf ("failed to create temporary file: %w" , err )
684+ }
685+ defer tempFile .Close ()
686+
687+ yamlData , err := yaml .Marshal (migConfig )
688+ if err != nil {
689+ return "" , fmt .Errorf ("failed to marshal MIG config to YAML: %w" , err )
690+ }
691+
692+ if _ , err := tempFile .Write (yamlData ); err != nil {
693+ return "" , fmt .Errorf ("failed to write YAML data to temporary file: %w" , err )
694+ }
695+
696+ return tempFile .Name (), nil
697+ }
698+
699+ func updateMIGConfig (clientset * kubernetes.Clientset , migPartedConfigFile , selectedConfig string ) error {
700+
701+ defer func () {
702+ if err := os .Remove (migPartedConfigFile ); err != nil {
703+ log .Errorf ("Failed to remove temporary mig-parted config file %s: %v" , migPartedConfigFile , err )
704+ }
705+ }()
706+
707+ gpuClients , err := parseGPUCLientsFile (gpuClientsFileFlag )
708+ if err != nil {
709+ return fmt .Errorf ("error parsing host's GPU clients file: %w" , err )
710+ }
711+
712+ opts := & reconfigureMIGOptions {
713+ NodeName : nodeNameFlag ,
714+ MIGPartedConfigFile : migPartedConfigFile ,
715+ SelectedMIGConfig : selectedConfig ,
716+ WithReboot : withRebootFlag ,
717+ WithShutdownHostGPUClients : withShutdownHostGPUClientsFlag ,
718+ HostRootMount : hostRootMountFlag ,
719+ HostNvidiaDir : hostNvidiaDirFlag ,
720+ HostMIGManagerStateFile : hostMigManagerStateFileFlag ,
721+ HostGPUClientServices : gpuClients .SystemdServices ,
722+ HostKubeletService : hostKubeletSystemdServiceFlag ,
723+ }
724+
725+ return reconfigureMIG (clientset , opts )
726+ }
727+
728+ func parseGPUCLientsFile (file string ) (* GPUClients , error ) {
729+ var err error
730+ var yamlBytes []byte
731+
732+ if file == "" {
733+ return & GPUClients {}, nil
734+ }
735+
736+ yamlBytes , err = os .ReadFile (file )
737+ if err != nil {
738+ return nil , fmt .Errorf ("read error: %w" , err )
739+ }
740+
741+ var clients GPUClients
742+ err = yaml .Unmarshal (yamlBytes , & clients )
743+ if err != nil {
744+ return nil , fmt .Errorf ("unmarshal error: %w" , err )
745+ }
746+
747+ return & clients , nil
748+ }
0 commit comments