Skip to content

Commit dad3cf0

Browse files
authored
add client dict to separate different sessions (#485)
Add a dict to store different client for different sessions. This is to resolve the bug #482
2 parents 776064b + 29a8f35 commit dad3cf0

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

taskweaver/ces/environment.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,10 @@ def __init__(
122122
custom_image: Optional[str] = None,
123123
) -> None:
124124
self.session_dict: Dict[str, EnvSession] = {}
125+
self.client_dict: Dict[str, BlockingKernelClient] = {}
125126
self.id = get_id(prefix="env") if env_id is None else env_id
126127
self.env_dir = env_dir if env_dir is not None else os.getcwd()
127128
self.mode = env_mode
128-
self._client: Optional[BlockingKernelClient] = None
129129

130130
if self.mode == EnvMode.Local:
131131
self.multi_kernel_manager = TaskWeaverMultiKernelManager(
@@ -396,7 +396,7 @@ def update_session_var(
396396
session.session_var.update(session_var)
397397

398398
def stop_session(self, session_id: str) -> None:
399-
self._clean_client()
399+
self._clean_client(session_id)
400400
session = self._get_session(session_id)
401401
if session is None:
402402
# session not exist
@@ -490,8 +490,8 @@ def _get_client(
490490
self,
491491
session_id: str,
492492
) -> BlockingKernelClient:
493-
if self._client is not None:
494-
return self._client
493+
if session_id in self.client_dict:
494+
return self.client_dict[session_id]
495495
session = self._get_session(session_id)
496496
connection_file = self._get_connection_file(session_id, session.kernel_id)
497497
client = BlockingKernelClient(connection_file=connection_file)
@@ -507,13 +507,13 @@ def _get_client(
507507
client.iopub_port = ports["iopub_port"]
508508
client.wait_for_ready(timeout=30)
509509
client.start_channels()
510-
self._client = client
510+
self.client_dict[session_id] = client
511511
return client
512512

513-
def _clean_client(self):
514-
if self._client is not None:
515-
self._client.stop_channels()
516-
self._client = None
513+
def _clean_client(self, session_id: str):
514+
if session_id in self.client_dict:
515+
self.client_dict[session_id].stop_channels()
516+
del self.client_dict[session_id]
517517

518518
def _execute_code_on_kernel(
519519
self,

0 commit comments

Comments
 (0)