diff --git a/gateway_provisioners/distributed.py b/gateway_provisioners/distributed.py index ff790d4..1e2c3d1 100644 --- a/gateway_provisioners/distributed.py +++ b/gateway_provisioners/distributed.py @@ -144,7 +144,7 @@ async def launch_kernel(self, cmd: list[str], **kwargs: Any) -> KernelConnection """ self.kernel_log = None env_dict = kwargs.get("env", {}) - self.assigned_host = self._determine_next_host(env_dict) + self.assigned_host = await self.determine_next_host(env_dict) self.ip = gethostbyname(self.assigned_host) # convert to ip if host is provided self.assigned_ip = self.ip @@ -291,9 +291,15 @@ def _build_startup_command(self, cmd: list[str], **kwargs: Any) -> list[str]: startup_cmd += f" {arg}" startup_cmd += f" >> {self.kernel_log} 2>&1 & echo $!" # return the process id - return startup_cmd + async def determine_next_host(self, env_dict: dict) -> str: + """ + This is a placeholder function which can be overridden to implement + custom logic to determine next host. + """ + return self._determine_next_host(env_dict) + def _determine_next_host(self, env_dict: dict) -> str: """Simple round-robin index into list of hosts.""" remote_host = env_dict.get("KERNEL_REMOTE_HOST")