@@ -26,6 +26,7 @@ import (
2626 "k8s.io/client-go/kubernetes"
2727 "k8s.io/client-go/tools/cache"
2828 "k8s.io/client-go/tools/clientcmd"
29+ "sigs.k8s.io/yaml"
2930
3031 "context"
3132 "sync"
@@ -38,6 +39,7 @@ import (
3839 "k8s.io/apimachinery/pkg/util/wait"
3940
4041 "github.com/NVIDIA/vgpu-device-manager/internal/info"
42+ "github.com/NVIDIA/vgpu-device-manager/pkg/types"
4143)
4244
4345const (
@@ -47,6 +49,13 @@ const (
4749 vGPUConfigStateLabel = "nvidia.com/vgpu.config.state"
4850 pluginStateLabel = "nvidia.com/gpu.deploy.sandbox-device-plugin"
4951 validatorStateLabel = "nvidia.com/gpu.deploy.sandbox-validator"
52+
53+ defaultHostRootMount = "/host"
54+ defaultHostNvidiaDir = "/usr/local/nvidia"
55+ defaultHostMigManagerStateFile = "/etc/systemd/system/nvidia-mig-manager.service.d/override.conf"
56+ defaultHostKubeletSystemdService = "kubelet.service"
57+
58+ migConfigDisabled = "all-disabled"
5059)
5160
5261var (
@@ -56,10 +65,24 @@ var (
5665 configFileFlag string
5766 defaultVGPUConfigFlag string
5867
68+ migPartedConfigFileFlag string
69+ hostRootMountFlag string
70+ hostNvidiaDirFlag string
71+ hostMigManagerStateFileFlag string
72+ hostKubeletSystemdServiceFlag string
73+ gpuClientsFileFlag string
74+ withRebootFlag bool
75+ withShutdownHostGPUClientsFlag bool
76+
5977 pluginDeployed string
6078 validatorDeployed string
6179)
6280
81+ type GPUClients struct {
82+ Version string `json:"version" yaml:"version"`
83+ SystemdServices []string `json:"systemd-services" yaml:"systemd-services"`
84+ }
85+
6386// SyncableVGPUConfig is used to synchronize on changes to a configuration value.
6487// That is, callers of Get() will block until a call to Set() is made.
6588// Multiple calls to Set() do not queue, meaning that only calls to Get() made
@@ -148,6 +171,70 @@ func main() {
148171 Destination : & defaultVGPUConfigFlag ,
149172 EnvVars : []string {"DEFAULT_VGPU_CONFIG" },
150173 },
174+ & cli.StringFlag {
175+ Name : "mig-parted-config-file" ,
176+ Aliases : []string {"mc" },
177+ Value : "" ,
178+ Usage : "the path to the mig-parted configuration file" ,
179+ Destination : & migPartedConfigFileFlag ,
180+ EnvVars : []string {"MIG_PARTED_CONFIG_FILE" },
181+ },
182+ & cli.StringFlag {
183+ Name : "host-root-mount" ,
184+ Aliases : []string {"m" },
185+ Value : defaultHostRootMount ,
186+ Usage : "container path where host root directory is mounted" ,
187+ Destination : & hostRootMountFlag ,
188+ EnvVars : []string {"HOST_ROOT_MOUNT" },
189+ },
190+ & cli.StringFlag {
191+ Name : "host-nvidia-dir" ,
192+ Aliases : []string {"i" },
193+ Value : defaultHostNvidiaDir ,
194+ Usage : "host path of the directory where NVIDIA managed software directory is typically located" ,
195+ Destination : & hostNvidiaDirFlag ,
196+ EnvVars : []string {"HOST_NVIDIA_DIR" },
197+ },
198+ & cli.StringFlag {
199+ Name : "host-mig-manager-state-file" ,
200+ Aliases : []string {"o" },
201+ Value : defaultHostMigManagerStateFile ,
202+ Usage : "host path where the host's systemd mig-manager state file is located" ,
203+ Destination : & hostMigManagerStateFileFlag ,
204+ EnvVars : []string {"HOST_MIG_MANAGER_STATE_FILE" },
205+ },
206+ & cli.StringFlag {
207+ Name : "host-kubelet-systemd-service" ,
208+ Aliases : []string {"k" },
209+ Value : defaultHostKubeletSystemdService ,
210+ Usage : "name of the host's 'kubelet' systemd service which may need to be shutdown/restarted across a MIG mode reconfiguration" ,
211+ Destination : & hostKubeletSystemdServiceFlag ,
212+ EnvVars : []string {"HOST_KUBELET_SYSTEMD_SERVICE" },
213+ },
214+ & cli.StringFlag {
215+ Name : "gpu-clients-file" ,
216+ Aliases : []string {"g" },
217+ Value : "" ,
218+ Usage : "the path to the file listing the GPU clients that need to be shutdown across a MIG configuration" ,
219+ Destination : & gpuClientsFileFlag ,
220+ EnvVars : []string {"GPU_CLIENTS_FILE" },
221+ },
222+ & cli.BoolFlag {
223+ Name : "with-reboot" ,
224+ Aliases : []string {"r" },
225+ Value : false ,
226+ Usage : "reboot the node if changing the MIG mode fails for any reason" ,
227+ Destination : & withRebootFlag ,
228+ EnvVars : []string {"WITH_REBOOT" },
229+ },
230+ & cli.BoolFlag {
231+ Name : "with-shutdown-host-gpu-clients" ,
232+ Aliases : []string {"w" },
233+ Value : false ,
234+ Usage : "shutdown/restart any required host GPU clients across a MIG configuration" ,
235+ Destination : & withShutdownHostGPUClientsFlag ,
236+ EnvVars : []string {"WITH_SHUTDOWN_HOST_GPU_CLIENTS" },
237+ },
151238 }
152239
153240 log .Infof ("version: %s" , c .Version )
@@ -296,6 +383,10 @@ func updateConfig(clientset *kubernetes.Clientset, selectedConfig string) error
296383 return fmt .Errorf ("unable to shutdown gpu operands: %v" , err )
297384 }
298385
386+ if err := handleMIGConfiguration (clientset , selectedConfig ); err != nil {
387+ return fmt .Errorf ("unable to handle MIG configuration: %v" , err )
388+ }
389+
299390 log .Info ("Applying the selected vGPU device configuration to the node" )
300391 err = applyConfig (selectedConfig )
301392 if err != nil {
@@ -504,3 +595,101 @@ func setNodeLabelValue(clientset *kubernetes.Clientset, label, value string) err
504595
505596 return nil
506597}
598+
599+ func handleMIGConfiguration (clientset kubernetes.Interface , selectedConfig string ) error {
600+ if err := isNVMLAvailable (); err != nil {
601+ log .Infof ("Skipping MIG configuration due to NVML error: %v, proceeding with vGPU configuration" , err )
602+ return nil
603+ }
604+
605+ migConfig , err := determineMIGConfig (selectedConfig )
606+ if err != nil {
607+ return err
608+ }
609+
610+ log .Infof ("Selected MIG configuration: %s" , migConfig )
611+ return updateMIGConfig (clientset .(* kubernetes.Clientset ), migConfig )
612+ }
613+
614+ func determineMIGConfig (selectedConfig string ) (string , error ) {
615+ vgpuType , err := types .ParseVGPUType (selectedConfig )
616+ if err != nil {
617+ return "" , fmt .Errorf ("unable to parse vGPU type: %s" , err )
618+ }
619+
620+ if vgpuType .G > 0 {
621+ migConfig , err := convertToMIGConfigFormat (selectedConfig )
622+ if err != nil {
623+ return "" , fmt .Errorf ("unable to convert vGPU type config to MIG config: %s" , err )
624+ }
625+ return migConfig , nil
626+ }
627+
628+ return migConfigDisabled , nil
629+ }
630+
631+ // convertToMIGConfigFormat converts a vGPU type string to the MIG config format.
632+ // Examples: "A100-1-5C" -> "all-1g.5gb", "A100-1-5CME" -> "all-1g.5gb.me"
633+ func convertToMIGConfigFormat (s string ) (string , error ) {
634+ vgpu , err := types .ParseVGPUType (s )
635+ if err != nil {
636+ return "" , fmt .Errorf ("failed to parse vGPU type: %v" , err )
637+ }
638+
639+ // Base format: all-{g}g.{gb}gb
640+ result := fmt .Sprintf ("all-%dg.%dgb" , vgpu .G , vgpu .GB )
641+
642+ // Add .me suffix if media extension attribute is present
643+ for _ , attr := range vgpu .Attr {
644+ if attr == types .AttributeMediaExtensions {
645+ result += ".me"
646+ break
647+ }
648+ }
649+
650+ return result , nil
651+ }
652+
653+ func updateMIGConfig (clientset * kubernetes.Clientset , migConfigValue string ) error {
654+ gpuClients , err := parseGPUCLientsFile (gpuClientsFileFlag )
655+ if err != nil {
656+ return fmt .Errorf ("error parsing host's GPU clients file: %s" , err )
657+ }
658+
659+ opts := & reconfigureMIGOptions {
660+ NodeName : nodeNameFlag ,
661+ MIGPartedConfigFile : migPartedConfigFileFlag ,
662+ SelectedMIGConfig : migConfigValue ,
663+ WithReboot : withRebootFlag ,
664+ WithShutdownHostGPUClients : withShutdownHostGPUClientsFlag ,
665+ HostRootMount : hostRootMountFlag ,
666+ HostNvidiaDir : hostNvidiaDirFlag ,
667+ HostMIGManagerStateFile : hostMigManagerStateFileFlag ,
668+ HostGPUClientServices : gpuClients .SystemdServices ,
669+ HostKubeletService : hostKubeletSystemdServiceFlag ,
670+ }
671+
672+ return reconfigureMIG (clientset , opts )
673+ }
674+
675+ func parseGPUCLientsFile (file string ) (* GPUClients , error ) {
676+ var err error
677+ var yamlBytes []byte
678+
679+ if file == "" {
680+ return & GPUClients {}, nil
681+ }
682+
683+ yamlBytes , err = os .ReadFile (file )
684+ if err != nil {
685+ return nil , fmt .Errorf ("read error: %v" , err )
686+ }
687+
688+ var clients GPUClients
689+ err = yaml .Unmarshal (yamlBytes , & clients )
690+ if err != nil {
691+ return nil , fmt .Errorf ("unmarshal error: %v" , err )
692+ }
693+
694+ return & clients , nil
695+ }
0 commit comments