Skip to content
This repository was archived by the owner on Sep 22, 2023. It is now read-only.

Commit fd73267

Browse files
authored
Fix inadvertent cancellation handling during app-proxy connection cleanup (#146)
* When the client (browser) closes a connection, the cancellation handler was inadvertently raising a CancelledError which prevented cleaning up the upper connection. * style: Update `ai.backend.client.cli.app` module's argument formatting to be like black
1 parent 58a984d commit fd73267

File tree

2 files changed

+42
-32
lines changed

2 files changed

+42
-32
lines changed

changes/146.fix

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix an inadvertent internal cancellation error when a client closes an app-proxy connection in the `app` command

src/ai/backend/client/cli/app.py

Lines changed: 41 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,19 @@ class WSProxy:
2626
'app_name', 'protocol',
2727
'args', 'envs',
2828
'reader', 'writer',
29-
'down_task',
3029
)
3130

32-
def __init__(self, api_session: AsyncSession,
33-
session_name: str,
34-
app_name: str,
35-
protocol: str,
36-
args: MutableMapping[str, Union[None, str, List[str]]],
37-
envs: MutableMapping[str, str],
38-
reader: asyncio.StreamReader,
39-
writer: asyncio.StreamWriter):
31+
def __init__(
32+
self,
33+
api_session: AsyncSession,
34+
session_name: str,
35+
app_name: str,
36+
protocol: str,
37+
args: MutableMapping[str, Union[None, str, List[str]]],
38+
envs: MutableMapping[str, str],
39+
reader: asyncio.StreamReader,
40+
writer: asyncio.StreamWriter,
41+
) -> None:
4042
self.api_session = api_session
4143
self.session_name = session_name
4244
self.app_name = app_name
@@ -45,9 +47,8 @@ def __init__(self, api_session: AsyncSession,
4547
self.envs = envs
4648
self.reader = reader
4749
self.writer = writer
48-
self.down_task = None
4950

50-
async def run(self):
51+
async def run(self) -> None:
5152
prefix = get_naming(self.api_session.api_version, 'path')
5253
path = f"/stream/{prefix}/{self.session_name}/{self.protocol}proxy"
5354
params = {'app': self.app_name}
@@ -63,7 +64,7 @@ async def run(self):
6364
content_type="application/json")
6465
async with api_rqst.connect_websocket() as ws:
6566

66-
async def downstream():
67+
async def downstream() -> None:
6768
try:
6869
async for msg in ws:
6970
if msg.type == aiohttp.WSMsgType.ERROR:
@@ -79,17 +80,16 @@ async def downstream():
7980
except ConnectionResetError:
8081
pass # shutting down
8182
except asyncio.CancelledError:
82-
raise
83+
pass
8384
finally:
8485
self.writer.close()
85-
if hasattr(self.writer, 'wait_closed'): # Python 3.7+
86-
try:
87-
await self.writer.wait_closed()
88-
except (BrokenPipeError, IOError):
89-
# closed
90-
pass
91-
92-
self.down_task = asyncio.ensure_future(downstream())
86+
try:
87+
await self.writer.wait_closed()
88+
except (BrokenPipeError, IOError):
89+
# closed
90+
pass
91+
92+
down_task = asyncio.create_task(downstream())
9393
try:
9494
while True:
9595
chunk = await self.reader.read(DEFAULT_CHUNK_SIZE)
@@ -101,11 +101,11 @@ async def downstream():
101101
except asyncio.CancelledError:
102102
raise
103103
finally:
104-
if not self.down_task.done():
105-
await self.down_task
106-
self.down_task = None
104+
if not down_task.done():
105+
down_task.cancel()
106+
await down_task
107107

108-
async def write_error(self, msg):
108+
async def write_error(self, msg: aiohttp.WSMessage) -> None:
109109
if isinstance(msg.data, bytes):
110110
error_msg = msg.data.decode('utf8')
111111
else:
@@ -138,11 +138,17 @@ class ProxyRunnerContext:
138138
local_server: Optional[asyncio.AbstractServer]
139139
exit_code: int
140140

141-
def __init__(self, host: str, port: int,
142-
session_name: str, app_name: str, *,
143-
protocol: str = 'http',
144-
args: Sequence[str] = None,
145-
envs: Sequence[str] = None) -> None:
141+
def __init__(
142+
self,
143+
host: str,
144+
port: int,
145+
session_name: str,
146+
app_name: str,
147+
*,
148+
protocol: str = 'http',
149+
args: Sequence[str] = None,
150+
envs: Sequence[str] = None,
151+
) -> None:
146152
self.host = host
147153
self.port = port
148154
self.session_name = session_name
@@ -180,8 +186,11 @@ def __init__(self, host: str, port: int,
180186
else:
181187
self.envs[split[0]] = ''
182188

183-
async def handle_connection(self, reader: asyncio.StreamReader,
184-
writer: asyncio.StreamWriter) -> None:
189+
async def handle_connection(
190+
self,
191+
reader: asyncio.StreamReader,
192+
writer: asyncio.StreamWriter,
193+
) -> None:
185194
assert self.api_session is not None
186195
p = WSProxy(self.api_session, self.session_name,
187196
self.app_name, self.protocol,

0 commit comments

Comments
 (0)