@@ -61,8 +61,12 @@ class NvidiaAPIClient(APIClient_orig):
6161 """ # noqa: E501
6262
6363 def is_nvidia_image (self , image ):
64- return (self .inspect_image (image ).get ('Config' , {}).get ('Labels' , {}).
65- get ('com.nvidia.volumes.needed' , None ) == 'nvidia_driver' )
64+ labels = self .inspect_image (image ).get ('Config' , {}).get ('Labels' , {})
65+ if labels :
66+ return (labels .get ('com.nvidia.volumes.needed' ,
67+ None ) == 'nvidia_driver' )
68+ else :
69+ return False
6670
6771 def create_container_config (self , image , * args , ** kwargs ):
6872 container_config = (super (NvidiaAPIClient , self ).
@@ -89,7 +93,7 @@ def get_nvidia_docker_endpoint():
8993 host = os .environ .get (NVIDIA_HOST ,
9094 "http://{}:{}" .format (NVIDIA_DEFAULT_HOST ,
9195 NVIDIA_DEFAULT_PORT ))
92- return host + '/docker/cli/json'
96+ return host + '/docker/cli/json'
9397
9498
9599def get_nvidia_configuration ():
@@ -120,6 +124,8 @@ def add_nvidia_docker_to_config(container_config, nvidia_config):
120124
121125 # Get nvidia control devices
122126 devices = container_config ['HostConfig' ].get ('Devices' , [])
127+ if devices is None :
128+ devices = []
123129 # suport both '0 1' and '0, 1' formats, just like nvidia-docker
124130 gpu_isolation = os .getenv ('NV_GPU' , '' ).replace (',' , ' ' ).split ()
125131 pattern = re .compile (r'/nvidia([0-9]+)$' )
0 commit comments