Skip to content

Commit 739983b

Browse files
committed
Begin summary command + misc (types, client, vendored sse_client, etc.)
1 parent 7ab9e5a commit 739983b

File tree

5 files changed

+337
-2
lines changed

5 files changed

+337
-2
lines changed
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""
2+
"""
3+
4+
from typing import Optional
5+
6+
import httpx
7+
8+
from .utils._headers import build_hf_headers
9+
from .utils._sse_client import SSEClient
10+
from ._hot_reloading_types import ApiGetReloadEventSourceData
11+
from ._hot_reloading_types import ApiGetReloadRequest
12+
13+
14+
HOT_RELOADING_PORT = 7887
15+
16+
17+
class ReloadClient:
18+
def __init__(
19+
self,
20+
*,
21+
host: str,
22+
subdomain: str,
23+
replica_hash: str,
24+
token: Optional[str],
25+
):
26+
base_host = host.replace(subdomain, f"{subdomain}--{HOT_RELOADING_PORT}")
27+
self.client = httpx.Client(
28+
base_url=f"{base_host}/--replicas/+{replica_hash}",
29+
headers=build_hf_headers(token=token),
30+
)
31+
32+
def get_reload(self, reload_id: str):
33+
req = ApiGetReloadRequest(reloadId=reload_id)
34+
with self.client.stream('POST', '/get-reload', json=req.model_dump()) as res:
35+
assert res.status_code == 200, res.status_code # TODO: Raise specific error ? Return ?
36+
for event in SSEClient(res.iter_bytes()).events():
37+
if event.event == 'message':
38+
yield ApiGetReloadEventSourceData.model_validate_json(event.data)
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
"""
2+
"""
3+
from typing import Literal
4+
5+
from pydantic import BaseModel
6+
7+
8+
class ReloadRegion(BaseModel):
9+
startLine: int
10+
startCol: int
11+
endLine: int
12+
endCol: int
13+
14+
15+
class ReloadOperationObject(BaseModel):
16+
kind: Literal['add', 'update', 'delete']
17+
region: ReloadRegion
18+
objectType: str
19+
objectName: str
20+
21+
22+
class ReloadOperationRun(BaseModel):
23+
kind: Literal['run']
24+
region: ReloadRegion
25+
codeLines: str
26+
stdout: str | None = None
27+
stderr: str | None = None
28+
29+
30+
class ReloadOperationException(BaseModel):
31+
kind: Literal['exception']
32+
region: ReloadRegion
33+
traceback: str
34+
35+
36+
class ReloadOperationError(BaseModel):
37+
kind: Literal['error']
38+
traceback: str
39+
40+
41+
class ReloadOperationUI(BaseModel):
42+
kind: Literal['ui']
43+
updated: bool
44+
45+
46+
class ApiCreateReloadRequest(BaseModel):
47+
filepath: str
48+
contents: str
49+
reloadId: str | None = None
50+
51+
52+
class ApiCreateReloadResponseSuccess(BaseModel):
53+
status: Literal['created']
54+
reloadId: str
55+
56+
57+
class ApiCreateReloadResponseError(BaseModel):
58+
status: Literal['alreadyReloading', 'fileNotFound']
59+
60+
61+
class ApiCreateReloadResponse(BaseModel):
62+
res: ApiCreateReloadResponseError | ApiCreateReloadResponseSuccess
63+
64+
65+
class ApiGetReloadRequest(BaseModel):
66+
reloadId: str
67+
68+
69+
class ApiGetReloadEventSourceData(BaseModel):
70+
data: ReloadOperationError \
71+
| ReloadOperationException \
72+
| ReloadOperationObject \
73+
| ReloadOperationRun \
74+
| ReloadOperationUI \
75+
76+
77+
class ApiGetStatusRequest(BaseModel):
78+
revision: str
79+
80+
81+
class ApiGetStatusResponse(BaseModel):
82+
reloading: bool
83+
uncommited: list[str]
84+
85+
86+
class ApiFetchContentsRequest(BaseModel):
87+
filepath: str
88+
89+
90+
class ApiFetchContentsResponseError(BaseModel):
91+
status: Literal['fileNotFound']
92+
93+
94+
class ApiFetchContentsResponseSuccess(BaseModel):
95+
status: Literal['ok']
96+
contents: str
97+
98+
99+
class ApiFetchContentsResponse(BaseModel):
100+
res: ApiFetchContentsResponseError | ApiFetchContentsResponseSuccess

src/huggingface_hub/_space_api.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from dataclasses import dataclass
1616
from datetime import datetime
1717
from enum import Enum
18-
from typing import Optional
18+
from typing import Literal, Optional, Union
1919

2020
from huggingface_hub.utils import parse_datetime
2121

@@ -99,6 +99,18 @@ class SpaceStorage(str, Enum):
9999
LARGE = "large"
100100

101101

102+
@dataclass
103+
class SpaceHotReloading:
104+
status: Literal["created", "canceled"]
105+
replica_statuses: list[tuple[str, str]] # See hot_reloading.types.ApiCreateReloadResponse.res.status
106+
raw: dict
107+
108+
def __init__(self, data: dict) -> None:
109+
self.status = data["status"]
110+
self.replica_statuses = data["replicaStatuses"]
111+
self.raw = data
112+
113+
102114
@dataclass
103115
class SpaceRuntime:
104116
"""
@@ -128,6 +140,7 @@ class SpaceRuntime:
128140
requested_hardware: Optional[SpaceHardware]
129141
sleep_time: Optional[int]
130142
storage: Optional[SpaceStorage]
143+
hot_reloading: Optional[SpaceHotReloading]
131144
raw: dict
132145

133146
def __init__(self, data: dict) -> None:
@@ -136,6 +149,7 @@ def __init__(self, data: dict) -> None:
136149
self.requested_hardware = data.get("hardware", {}).get("requested")
137150
self.sleep_time = data.get("gcTimeout")
138151
self.storage = data.get("storage")
152+
self.hot_reloading = SpaceHotReloading(raw_hr) if (raw_hr := data.get("hotReloading")) is not None else None
139153
self.raw = data
140154

141155

src/huggingface_hub/cli/spaces.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from huggingface_hub.cli import _cli_utils
3737
from huggingface_hub.errors import CLIError, RepositoryNotFoundError, RevisionNotFoundError
3838
from huggingface_hub.file_download import hf_hub_download
39-
from huggingface_hub.hf_api import ExpandSpaceProperty_T, SpaceSort_T
39+
from huggingface_hub.hf_api import ExpandSpaceProperty_T, HfApi, SpaceSort_T
4040
from huggingface_hub.utils import are_progress_bars_disabled, disable_progress_bars, enable_progress_bars
4141

4242
from ._cli_utils import (
@@ -210,3 +210,50 @@ def spaces_hot_reload(
210210
# hot-reloading summary
211211
# TODO
212212
pass
213+
214+
215+
def _spaces_hot_reloading_summary(
216+
api: HfApi,
217+
space_id: str,
218+
commit_sha: str,
219+
token: Optional[str],
220+
) -> None:
221+
from huggingface_hub._hot_reloading_client import ReloadClient
222+
223+
space_info = api.space_info(space_id)
224+
if (runtime := space_info.runtime) is None:
225+
raise CLIError(f"Unable to read SpaceRuntime from {space_id} infos")
226+
if (hot_reloading := runtime.hot_reloading) is None:
227+
raise CLIError(f"Space {space_id} current running version has not been hot-reloaded")
228+
if hot_reloading.status != "created":
229+
typer.echo("...")
230+
return
231+
232+
if (space_host := space_info.host) is None:
233+
raise CLIError(f"Unexpected None host on hotReloaded Space")
234+
if (space_subdomain := space_info.subdomain) is None:
235+
raise CLIError(f"Unexpected None subdomain on hotReloaded Space")
236+
237+
clients = [ReloadClient(
238+
host=space_host,
239+
subdomain=space_subdomain,
240+
replica_hash=hash,
241+
token=token,
242+
) for hash, _ in hot_reloading.replica_statuses]
243+
244+
...
245+
# Fetch first client (display replica hash if multiple)
246+
# Compare others
247+
# Display final info (Success, Hot-reloading contains errors, etc.)
248+
249+
250+
251+
@spaces_hot_reloading_cli.command("summary")
252+
def spaces_hot_reloading_summary(
253+
space_id: Annotated[str, typer.Argument(help="The space ID (e.g. `username/repo-name`).")],
254+
commit_sha: Annotated[str, typer.Argument(help="...")],
255+
token: TokenOpt = None,
256+
):
257+
""" Description """
258+
api = get_hf_api(token=token)
259+
_spaces_hot_reloading_summary(api, space_id, commit_sha, token)
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
"""
2+
Server Side Events (SSE) client for Python.
3+
4+
Provides a generator of SSE received through an existing HTTP response.
5+
"""
6+
7+
import logging
8+
9+
__author__ = 'Maxime Petazzoni <maxime.petazzoni@bulix.org>'
10+
__email__ = 'maxime.petazzoni@bulix.org'
11+
__all__ = ['SSEClient']
12+
13+
_FIELD_SEPARATOR = ':'
14+
15+
16+
class SSEClient(object):
17+
"""Implementation of a SSE client.
18+
19+
See http://www.w3.org/TR/2009/WD-eventsource-20091029/ for the
20+
specification.
21+
"""
22+
23+
def __init__(self, event_source, char_enc='utf-8'):
24+
"""Initialize the SSE client over an existing, ready to consume
25+
event source.
26+
27+
The event source is expected to be a binary stream and have a close()
28+
method. That would usually be something that implements
29+
io.BinaryIOBase, like an httplib or urllib3 HTTPResponse object.
30+
"""
31+
self._logger = logging.getLogger(self.__class__.__module__)
32+
self._logger.debug('Initialized SSE client from event source %s',
33+
event_source)
34+
self._event_source = event_source
35+
self._char_enc = char_enc
36+
37+
def _read(self):
38+
"""Read the incoming event source stream and yield event chunks.
39+
40+
Unfortunately it is possible for some servers to decide to break an
41+
event into multiple HTTP chunks in the response. It is thus necessary
42+
to correctly stitch together consecutive response chunks and find the
43+
SSE delimiter (empty new line) to yield full, correct event chunks."""
44+
data = b''
45+
for chunk in self._event_source:
46+
for line in chunk.splitlines(True):
47+
data += line
48+
if data.endswith((b'\r\r', b'\n\n', b'\r\n\r\n')):
49+
yield data
50+
data = b''
51+
if data:
52+
yield data
53+
54+
def events(self):
55+
for chunk in self._read():
56+
event = Event()
57+
# Split before decoding so splitlines() only uses \r and \n
58+
for line in chunk.splitlines():
59+
# Decode the line.
60+
line = line.decode(self._char_enc)
61+
62+
# Lines starting with a separator are comments and are to be
63+
# ignored.
64+
if not line.strip() or line.startswith(_FIELD_SEPARATOR):
65+
continue
66+
67+
data = line.split(_FIELD_SEPARATOR, 1)
68+
field = data[0]
69+
70+
# Ignore unknown fields.
71+
if field not in event.__dict__:
72+
self._logger.debug('Saw invalid field %s while parsing '
73+
'Server Side Event', field)
74+
continue
75+
76+
if len(data) > 1:
77+
# From the spec:
78+
# "If value starts with a single U+0020 SPACE character,
79+
# remove it from value."
80+
if data[1].startswith(' '):
81+
value = data[1][1:]
82+
else:
83+
value = data[1]
84+
else:
85+
# If no value is present after the separator,
86+
# assume an empty value.
87+
value = ''
88+
89+
# The data field may come over multiple lines and their values
90+
# are concatenated with each other.
91+
if field == 'data':
92+
event.__dict__[field] += value + '\n'
93+
else:
94+
event.__dict__[field] = value
95+
96+
# Events with no data are not dispatched.
97+
if not event.data:
98+
continue
99+
100+
# If the data field ends with a newline, remove it.
101+
if event.data.endswith('\n'):
102+
event.data = event.data[0:-1]
103+
104+
# Empty event names default to 'message'
105+
event.event = event.event or 'message'
106+
107+
# Dispatch the event
108+
self._logger.debug('Dispatching %s...', event)
109+
yield event
110+
111+
def close(self):
112+
"""Manually close the event source stream."""
113+
self._event_source.close()
114+
115+
116+
class Event(object):
117+
"""Representation of an event from the event stream."""
118+
119+
def __init__(self, id=None, event='message', data='', retry=None):
120+
self.id = id
121+
self.event = event
122+
self.data = data
123+
self.retry = retry
124+
125+
def __str__(self):
126+
s = '{0} event'.format(self.event)
127+
if self.id:
128+
s += ' #{0}'.format(self.id)
129+
if self.data:
130+
s += ', {0} byte{1}'.format(len(self.data),
131+
's' if len(self.data) else '')
132+
else:
133+
s += ', no data'
134+
if self.retry:
135+
s += ', retry in {0}ms'.format(self.retry)
136+
return s

0 commit comments

Comments
 (0)