@@ -8,7 +8,7 @@ from docker.utils.utils import parse_devices
88
99from compose .cli .main import main
1010from docker .api .client import APIClient as APIClient_orig
11- from docker .errors import DockerException
11+ from docker .errors import DockerException , NotFound
1212
1313import contextlib
1414import compose .cli .errors
@@ -69,10 +69,21 @@ class NvidiaAPIClient(APIClient_orig):
6969 create_container_config (image , * args , ** kwargs ))
7070
7171 if self .is_nvidia_image (image ):
72- add_nvidia_docker_to_config (container_config )
72+ nvidia_config = get_nvidia_configuration ()
73+ self .create_nvidia_volume (nvidia_config )
74+ add_nvidia_docker_to_config (container_config , nvidia_config )
7375
7476 return container_config
7577
78+ def create_nvidia_volume (self , nvidia_config ):
79+ """ Create the nvidia volume if it doesn't already exist """
80+ try :
81+ volume_name = nvidia_config ['Volumes' ][0 ].split (':' )[0 ]
82+ self .inspect_volume (volume_name )
83+ except NotFound :
84+ self .create_volume (volume_name ,
85+ driver = nvidia_config ['VolumeDriver' ])
86+
7687
7788def get_nvidia_docker_endpoint ():
7889 host = os .environ .get (NVIDIA_HOST ,
@@ -98,17 +109,12 @@ def nvidia_docker_compatible():
98109 find_executable ('nvidia-docker' ))
99110
100111
101- def add_nvidia_docker_to_config (container_config ):
112+ def add_nvidia_docker_to_config (container_config , nvidia_config ):
102113
103114 if not container_config .get ('HostConfig' , None ):
104115 container_config ['HostConfig' ] = {}
105116
106- nvidia_config = get_nvidia_configuration ()
107-
108- # Setup the Volumes
109- container_config ['HostConfig' ].setdefault ('VolumeDriver' ,
110- nvidia_config ['VolumeDriver' ])
111-
117+ # # Setup the Volumes
112118 container_config ['HostConfig' ].setdefault ('Binds' , [])
113119 container_config ['HostConfig' ]['Binds' ].extend (nvidia_config ['Volumes' ])
114120
0 commit comments