@@ -30,72 +30,80 @@ import (
3030type bindCommand struct {
3131 logger * logrus.Logger
3232 nvpciLib nvpci.Interface
33+ options bindOptions
3334}
3435
3536type bindOptions struct {
3637 all bool
3738 deviceID string
39+ hostRoot string
3840}
3941
4042// newBindCommand constructs a bind command with the specified logger
4143func newBindCommand (logger * logrus.Logger ) * cli.Command {
4244 c := bindCommand {
4345 logger : logger ,
44- nvpciLib : nvpci .New (
45- nvpci .WithLogger (logger ),
46- ),
4746 }
4847 return c .build ()
4948}
5049
5150// build the bind command
5251func (m bindCommand ) build () * cli.Command {
53- cfg := bindOptions {}
54-
55- // Create the 'bind' command
5652 c := cli.Command {
5753 Name : "bind" ,
5854 Usage : "Bind device(s) to vfio-pci driver" ,
5955 Before : func (c * cli.Context ) error {
60- return m .validateFlags (& cfg )
56+ return m .validateFlags ()
6157 },
6258 Action : func (c * cli.Context ) error {
63- return m .run (& cfg )
59+ return m .run ()
6460 },
6561 Flags : []cli.Flag {
6662 & cli.BoolFlag {
6763 Name : "all" ,
6864 Aliases : []string {"a" },
69- Destination : & cfg .all ,
65+ Destination : & m . options .all ,
7066 Usage : "Bind all NVIDIA devices to vfio-pci" ,
7167 },
7268 & cli.StringFlag {
7369 Name : "device-id" ,
7470 Aliases : []string {"d" },
75- Destination : & cfg .deviceID ,
71+ Destination : & m . options .deviceID ,
7672 Usage : "Specific device ID to bind (e.g., 0000:01:00.0)" ,
7773 },
74+ & cli.StringFlag {
75+ Name : "host-root" ,
76+ Destination : & m .options .hostRoot ,
77+ EnvVars : []string {"HOST_ROOT" },
78+ Value : "/" ,
79+ Usage : "Path to the host's root filesystem. This is used when loading the vfio-pci module." ,
80+ },
7881 },
7982 }
8083
8184 return & c
8285}
8386
84- func (m bindCommand ) validateFlags (cfg * bindOptions ) error {
85- if ! cfg . all && cfg .deviceID == "" {
87+ func (m bindCommand ) validateFlags () error {
88+ if ! m . options . all && m . options .deviceID == "" {
8689 return fmt .Errorf ("either --all or --device-id must be specified" )
8790 }
8891
89- if cfg . all && cfg .deviceID != "" {
92+ if m . options . all && m . options .deviceID != "" {
9093 return fmt .Errorf ("cannot specify both --all and --device-id" )
9194 }
9295
9396 return nil
9497}
9598
96- func (m bindCommand ) run (cfg * bindOptions ) error {
97- if cfg .deviceID != "" {
98- return m .bindDevice (cfg .deviceID )
99+ func (m bindCommand ) run () error {
100+ m .nvpciLib = nvpci .New (
101+ nvpci .WithLogger (m .logger ),
102+ nvpci .WithHostRoot (m .options .hostRoot ),
103+ )
104+
105+ if m .options .deviceID != "" {
106+ return m .bindDevice ()
99107 }
100108
101109 return m .bindAll ()
@@ -118,7 +126,8 @@ func (m bindCommand) bindAll() error {
118126 return nil
119127}
120128
121- func (m bindCommand ) bindDevice (device string ) error {
129+ func (m bindCommand ) bindDevice () error {
130+ device := m .options .deviceID
122131 nvdev , err := m .nvpciLib .GetGPUByPciBusID (device )
123132 if err != nil {
124133 return fmt .Errorf ("failed to get NVIDIA GPU device: %w" , err )
0 commit comments