Skip to content
This repository was archived by the owner on Sep 22, 2023. It is now read-only.

Commit e5f632e

Browse files
authored
Refine app proxy support (#90)
* fix(cli.app): Add explicit failure handling for app discovery failure * fix(cli.app): Ensure repeated proxy_ctx invocation has the same result
1 parent 72bd87b commit e5f632e

File tree

1 file changed

+30
-19
lines changed
  • src/ai/backend/client/cli

1 file changed

+30
-19
lines changed

src/ai/backend/client/cli/app.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import asyncio
22
import json
3+
import os
34
import shlex
45
import signal
6+
import sys
57
from typing import (
68
Union, Optional,
79
MutableMapping, Dict,
@@ -12,7 +14,7 @@
1214
import click
1315

1416
from . import main
15-
from .pretty import print_info, print_warn, print_error
17+
from .pretty import print_info, print_warn, print_fail, print_error
1618
from ..config import DEFAULT_CHUNK_SIZE
1719
from ..request import Request
1820
from ..session import AsyncSession
@@ -120,6 +122,7 @@ class ProxyRunnerContext:
120122
'protocol', 'host', 'port',
121123
'args', 'envs',
122124
'api_session', 'local_server',
125+
'exit_code',
123126
)
124127

125128
session_name: str
@@ -131,6 +134,7 @@ class ProxyRunnerContext:
131134
envs: Dict[str, str]
132135
api_session: Optional[AsyncSession]
133136
local_server: Optional[asyncio.AbstractServer]
137+
exit_code: int
134138

135139
def __init__(self, host: str, port: int,
136140
session_name: str, app_name: str, *,
@@ -145,6 +149,7 @@ def __init__(self, host: str, port: int,
145149

146150
self.api_session = None
147151
self.local_server = None
152+
self.exit_code = 0
148153

149154
self.args, self.envs = {}, {}
150155
if len(args) > 0:
@@ -187,24 +192,26 @@ async def handle_connection(self, reader: asyncio.StreamReader,
187192
print_error(e)
188193

189194
async def __aenter__(self) -> None:
195+
self.exit_code = 0
190196
self.api_session = AsyncSession()
191197
await self.api_session.__aenter__()
192-
self.local_server = await asyncio.start_server(
193-
self.handle_connection, self.host, self.port)
194198

195199
user_url_template = "{protocol}://{host}:{port}"
196-
try:
197-
compute_session = self.api_session.ComputeSession(self.session_name)
198-
data = await compute_session.stream_app_info(self.app_name)
199-
if 'url_template' in data.keys():
200-
user_url_template = data['url_template']
201-
except:
202-
if self.app_name == 'vnc-web':
203-
user_url_template = \
204-
"{protocol}://{host}:{port}/vnc.html" \
205-
"?host={host}&port={port}" \
206-
"&password=backendai&autoconnect=true"
200+
compute_session = self.api_session.ComputeSession(self.session_name)
201+
all_apps = await compute_session.stream_app_info()
202+
for app_info in all_apps:
203+
if app_info['name'] == self.app_name:
204+
if 'url_template' in app_info.keys():
205+
user_url_template = app_info['url_template']
206+
break
207+
else:
208+
print_fail(f'The app "{self.app_name}" is not supported by the session.')
209+
self.exit_code = 1
210+
os.kill(0, signal.SIGINT)
211+
return
207212

213+
self.local_server = await asyncio.start_server(
214+
self.handle_connection, self.host, self.port)
208215
user_url = user_url_template.format(
209216
protocol=self.protocol,
210217
host=self.host,
@@ -220,13 +227,16 @@ async def __aenter__(self) -> None:
220227
'to connect with the CLI app proxy.')
221228

222229
async def __aexit__(self, *exc_info) -> None:
223-
print_info("Shutting down....")
224-
self.local_server.close()
225-
await self.local_server.wait_closed()
230+
if self.local_server is not None:
231+
print_info("Shutting down....")
232+
self.local_server.close()
233+
await self.local_server.wait_closed()
226234
await self.api_session.__aexit__(*exc_info)
227235
assert self.api_session.closed
228-
print_info("The local proxy to \"{}\" has terminated."
229-
.format(self.app_name))
236+
if self.local_server is not None:
237+
print_info("The local proxy to \"{}\" has terminated."
238+
.format(self.app_name))
239+
self.local_server = None
230240

231241

232242
@main.command()
@@ -266,6 +276,7 @@ def app(session_name, app, protocol, bind, arg, env):
266276
)
267277
stop_signals = {signal.SIGINT, signal.SIGTERM}
268278
asyncio_run_forever(proxy_ctx, stop_signals=stop_signals)
279+
sys.exit(proxy_ctx.exit_code)
269280

270281

271282
@main.command()

0 commit comments

Comments
 (0)