From 19ebeddc6a3967938000fcf6ba38007922fb6186 Mon Sep 17 00:00:00 2001 From: Divyansh Choudhary Date: Thu, 23 Nov 2023 19:05:37 +0530 Subject: [PATCH] Make determine_next_host an async function --- gateway_provisioners/distributed.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/gateway_provisioners/distributed.py b/gateway_provisioners/distributed.py index 2a7ef94..ead0274 100644 --- a/gateway_provisioners/distributed.py +++ b/gateway_provisioners/distributed.py @@ -143,7 +143,7 @@ async def launch_kernel(self, cmd: tyList[str], **kwargs: Any) -> KernelConnecti """ self.kernel_log = None env_dict = kwargs.get("env", {}) - self.assigned_host = self._determine_next_host(env_dict) + self.assigned_host = await self.determine_next_host(env_dict) self.ip = gethostbyname(self.assigned_host) # convert to ip if host is provided self.assigned_ip = self.ip @@ -286,9 +286,15 @@ def _build_startup_command(self, cmd: tyList[str], **kwargs: Any) -> tyList[str] startup_cmd += f" {arg}" startup_cmd += f" >> {self.kernel_log} 2>&1 & echo $!" # return the process id - return startup_cmd + async def determine_next_host(self, env_dict: dict) -> str: + """ + This is a placeholder function which can be overridden to implement + custom logic to determine next host. + """ + return self._determine_next_host(env_dict) + def _determine_next_host(self, env_dict: dict) -> str: """Simple round-robin index into list of hosts.""" remote_host = env_dict.get("KERNEL_REMOTE_HOST")