diff --git a/taskweaver/ces/environment.py b/taskweaver/ces/environment.py index 840002049..e5f30ed16 100644 --- a/taskweaver/ces/environment.py +++ b/taskweaver/ces/environment.py @@ -122,10 +122,10 @@ def __init__( custom_image: Optional[str] = None, ) -> None: self.session_dict: Dict[str, EnvSession] = {} + self.client_dict: Dict[str, BlockingKernelClient] = {} self.id = get_id(prefix="env") if env_id is None else env_id self.env_dir = env_dir if env_dir is not None else os.getcwd() self.mode = env_mode - self._client: Optional[BlockingKernelClient] = None if self.mode == EnvMode.Local: self.multi_kernel_manager = TaskWeaverMultiKernelManager( @@ -396,7 +396,7 @@ def update_session_var( session.session_var.update(session_var) def stop_session(self, session_id: str) -> None: - self._clean_client() + self._clean_client(session_id) session = self._get_session(session_id) if session is None: # session not exist @@ -490,8 +490,8 @@ def _get_client( self, session_id: str, ) -> BlockingKernelClient: - if self._client is not None: - return self._client + if session_id in self.client_dict: + return self.client_dict[session_id] session = self._get_session(session_id) connection_file = self._get_connection_file(session_id, session.kernel_id) client = BlockingKernelClient(connection_file=connection_file) @@ -507,13 +507,13 @@ def _get_client( client.iopub_port = ports["iopub_port"] client.wait_for_ready(timeout=30) client.start_channels() - self._client = client + self.client_dict[session_id] = client return client - def _clean_client(self): - if self._client is not None: - self._client.stop_channels() - self._client = None + def _clean_client(self, session_id: str): + if session_id in self.client_dict: + self.client_dict[session_id].stop_channels() + del self.client_dict[session_id] def _execute_code_on_kernel( self,