@@ -41,9 +41,12 @@ const DockerIMType IMType = "docker"
4141
4242type DockerIMConfig struct {
4343 DockerImageName string
44+ GpuManufacturer string
4445 HostOrchestratorPort int
4546}
4647
48+ const gpuManufacturerNvidia = "nvidia"
49+
4750const (
4851 dockerLabelCreatedBy = "created_by"
4952 dockerLabelKeyManagedBy = "managed_by"
@@ -66,11 +69,14 @@ const (
6669 DeleteHostOPType OPType = "deletehost"
6770)
6871
69- func NewDockerInstanceManager (cfg Config , cli * client.Client ) * DockerInstanceManager {
72+ func NewDockerInstanceManager (cfg Config , cli * client.Client ) (* DockerInstanceManager , error ) {
73+ if cfg .Docker .GpuManufacturer != "" && cfg .Docker .GpuManufacturer != gpuManufacturerNvidia {
74+ return nil , fmt .Errorf ("unsupported GPU manufacturer: %q" , cfg .Docker .GpuManufacturer )
75+ }
7076 return & DockerInstanceManager {
7177 Config : cfg ,
7278 Client : cli ,
73- }
79+ }, nil
7480}
7581
7682func (m * DockerInstanceManager ) ListZones () (* apiv1.ListZonesResponse , error ) {
@@ -371,6 +377,9 @@ func (m *DockerInstanceManager) createDockerContainer(ctx context.Context, user
371377 Tty : true ,
372378 Labels : dockerLabelsDict (user ),
373379 }
380+ if m .Config .Docker .GpuManufacturer == gpuManufacturerNvidia {
381+ config .Env = []string {"NVIDIA_DRIVER_CAPABILITIES=all" }
382+ }
374383 hostConfig := & container.HostConfig {
375384 Mounts : []mount.Mount {
376385 {
@@ -381,6 +390,17 @@ func (m *DockerInstanceManager) createDockerContainer(ctx context.Context, user
381390 },
382391 Privileged : true ,
383392 }
393+ if m .Config .Docker .GpuManufacturer == gpuManufacturerNvidia {
394+ hostConfig .Resources = container.Resources {
395+ DeviceRequests : []container.DeviceRequest {
396+ {
397+ Count : - 1 ,
398+ Capabilities : [][]string {{"gpu" }},
399+ },
400+ },
401+ }
402+ hostConfig .Runtime = "nvidia"
403+ }
384404 createRes , err := m .Client .ContainerCreate (ctx , config , hostConfig , nil , nil , "" )
385405 if err != nil {
386406 return "" , fmt .Errorf ("failed to create docker container: %w" , err )
0 commit comments