4343
4444
4545if TYPE_CHECKING :
46+ from docker import DockerClient
4647 from docker .models .containers import Container
4748
4849log : logging .Logger = logging .getLogger (__name__ )
@@ -94,6 +95,30 @@ def has_docker() -> bool:
9495 return False
9596
9697
98+ def ensure_network (client : Optional ["DockerClient" ] = None ) -> None :
99+ """
100+ This creates the torchx docker network. Multi-process safe.
101+ """
102+ import filelock
103+ from docker .errors import APIError
104+
105+ if client is None :
106+ import docker
107+
108+ client = docker .from_env ()
109+
110+ lock_path = os .path .join (tempfile .gettempdir (), "torchx_docker_network_lock" )
111+
112+ # Docker networks.create check_duplicate has a race condition so we need
113+ # to do client side locking to ensure only one network is created.
114+ with filelock .FileLock (lock_path , timeout = 10 ):
115+ try :
116+ client .networks .create (name = NETWORK , driver = "bridge" , check_duplicate = True )
117+ except APIError as e :
118+ if "already exists" not in str (e ):
119+ raise
120+
121+
97122class DockerOpts (TypedDict , total = False ):
98123 copy_env : Optional [List [str ]]
99124
@@ -145,24 +170,6 @@ class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]):
145170 def __init__ (self , session_name : str ) -> None :
146171 super ().__init__ ("docker" , session_name )
147172
148- def _ensure_network (self ) -> None :
149- import filelock
150- from docker .errors import APIError
151-
152- client = self ._docker_client
153- lock_path = os .path .join (tempfile .gettempdir (), "torchx_docker_network_lock" )
154-
155- # Docker networks.create check_duplicate has a race condition so we need
156- # to do client side locking to ensure only one network is created.
157- with filelock .FileLock (lock_path , timeout = 10 ):
158- try :
159- client .networks .create (
160- name = NETWORK , driver = "bridge" , check_duplicate = True
161- )
162- except APIError as e :
163- if "already exists" not in str (e ):
164- raise
165-
166173 def schedule (self , dryrun_info : AppDryRunInfo [DockerJob ]) -> str :
167174 client = self ._docker_client
168175
@@ -180,7 +187,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[DockerJob]) -> str:
180187 except Exception as e :
181188 log .warning (f"failed to pull image { image } , falling back to local: { e } " )
182189
183- self ._ensure_network ( )
190+ ensure_network ( self ._docker_client )
184191
185192 for container in req .containers :
186193 client .containers .run (
0 commit comments