Skip to content

Commit f02782a

Browse files
[https://nvbugs/5726066][fix] fix auto-scaling related failures (#9845)
Signed-off-by: Lizhi Zhou <[email protected]> Co-authored-by: Emma Qiao <[email protected]>
1 parent 6fe89ea commit f02782a

File tree

5 files changed

+187
-115
lines changed

5 files changed

+187
-115
lines changed

tensorrt_llm/serve/cluster_storage.py

Lines changed: 61 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,27 @@ def add_events_from_resp(self, watch_resp):
432432
self.cancel_event()
433433

434434

435+
def handle_etcd_error(return_on_error: bool | None):
436+
437+
def decorator(f):
438+
439+
async def wrap(*args, **kwargs):
440+
storage = args[0]
441+
try:
442+
return await f(*args, **kwargs)
443+
except etcd3.Etcd3Exception as e:
444+
logger.error(f"Etcd3 error in {f.__name__}: {e}")
445+
return return_on_error
446+
except ValueError:
447+
logger.error(f"Etcd client value error in {f.__name__}")
448+
storage.reconnect()
449+
return return_on_error
450+
451+
return wrap
452+
453+
return decorator
454+
455+
435456
class Etcd3ClusterStorage(ClusterStorage):
436457

437458
def __init__(self,
@@ -440,8 +461,8 @@ def __init__(self,
440461
one_single_lease: bool = False,
441462
**kwargs):
442463
cluster_uri = cluster_uri.replace("etcd://", "")
443-
host, port = cluster_uri.rsplit(":", 1)
444-
self._client = etcd3.client(host, port)
464+
self._host, self._port = cluster_uri.rsplit(":", 1)
465+
self._client = etcd3.client(self._host, self._port)
445466
self._leases = {}
446467
self._instance_lease = None
447468
self._watch_handles = {}
@@ -464,6 +485,10 @@ def _get_lease(self, key: str, ttl: int = -1) -> etcd3.Lease:
464485
def client(self):
465486
return self._client
466487

488+
def reconnect(self):
    """Replace the etcd client with a new one for the stored host and port.

    Only ``self._client`` is reassigned here.
    NOTE(review): leases and watch handles cached on this object are not
    touched by this method -- confirm whether they stay valid after a
    reconnect.
    """
    endpoint = (self._host, self._port)
    logger.info(f"Reconnecting to {self._host}:{self._port}")
    self._client = etcd3.client(*endpoint)
491+
467492
async def start(self):
468493
# nothing to do
469494
...
@@ -472,78 +497,60 @@ async def stop(self):
472497
# nothing to do
473498
...
474499

500+
@handle_etcd_error(return_on_error=False)
475501
async def set(self,
476502
key: str,
477503
value: str,
478504
overwrite_if_exists: bool = False,
479505
ttl: int = -1) -> bool:
480-
try:
481-
lease = self._get_lease(key, ttl)
482-
if not overwrite_if_exists:
483-
return self.client.put_if_not_exists(key, value, lease=lease)
484-
else:
485-
self.client.put(key, value, lease=lease)
486-
except etcd3.Etcd3Exception as e:
487-
logger.error(f"Error setting key {key}: {e}")
488-
return False
489-
return True
506+
lease = self._get_lease(key, ttl)
507+
if not overwrite_if_exists:
508+
return self.client.put_if_not_exists(key, value, lease=lease)
509+
else:
510+
self.client.put(key, value, lease=lease)
511+
return True
490512

513+
@handle_etcd_error(return_on_error=None)
491514
async def get(self, key: str) -> str:
492-
try:
493-
data, meta = self.client.get(key)
494-
return data.decode('utf-8') if data else None
495-
except etcd3.Etcd3Exception as e:
496-
logger.error(f"Error getting key {key}: {e}")
497-
return None
515+
data, meta = self.client.get(key)
516+
return data.decode('utf-8') if data else None
498517

518+
@handle_etcd_error(return_on_error=False)
async def delete(self, key: str) -> bool:
    """Delete ``key`` from etcd.

    Args:
        key: The key to remove.

    Returns:
        ``True`` when the delete call completes without error; ``False`` when
        the ``handle_etcd_error`` decorator intercepts a failure.
    """
    self.client.delete(key)
    # Without this explicit return, success would yield None, which is
    # indistinguishable from failure for callers testing truthiness.
    return True
506521

522+
@handle_etcd_error(return_on_error=False)
async def expire(self, key: str, ttl: int) -> bool:
    """Refresh the lease associated with ``key``.

    Args:
        key: The key whose lease should be refreshed.
        ttl: Must be greater than 0. It is passed to ``_get_lease``; per the
            inline note below, it is ignored when refreshing.

    Returns:
        ``True`` if the lease refresh was issued; ``False`` if ``ttl`` is not
        positive, no lease could be obtained, or the ``handle_etcd_error``
        decorator intercepts a failure.
    """
    if ttl <= 0:
        logger.error(f"TTL must be greater than 0, got {ttl}")
        return False
    lease = self._get_lease(key, ttl)
    # Explicit check instead of `assert`: asserts are stripped under -O and
    # an AssertionError would bypass this method's return-False convention.
    if lease is None:
        logger.error(f"Lease must be created for key {key}")
        return False
    # TTL will be ignored since it can only be set when creating a lease
    next(self.client.refresh_lease(lease_id=lease.id), None)
    return True
518532

533+
@handle_etcd_error(return_on_error={})
519534
async def get_prefix(self,
520535
key_prefix: str,
521536
keys_only: bool = False) -> Dict[str, str]:
522-
try:
523-
resp = self.client.get_prefix(key_prefix)
524-
return {
525-
metadata.key.decode("utf-8"):
526-
"" if keys_only else v.decode("utf-8")
527-
for v, metadata in resp
528-
}
529-
except etcd3.Etcd3Exception as e:
530-
logger.error(f"Error getting keys {key_prefix}: {e}")
531-
return {}
537+
resp = self.client.get_prefix(key_prefix)
538+
return {
539+
metadata.key.decode("utf-8"): "" if keys_only else v.decode("utf-8")
540+
for v, metadata in resp
541+
}
532542

543+
@handle_etcd_error(return_on_error=None)
533544
async def watch(self, key_prefix: str) -> WatchEventQueue:
534-
try:
535-
if key_prefix in self._watch_handles:
536-
return self._watch_handles[key_prefix]
537-
watch_handle = Etcd3WatchEventQueue(key_prefix=key_prefix)
538-
watch_id = self.client.add_watch_prefix_callback(
539-
key_prefix, watch_handle.add_events_from_resp)
540-
watch_handle.set_cancel_event(
541-
lambda: self.client.cancel_watch(watch_id))
542-
self._watch_handles[key_prefix] = watch_handle
543-
return watch_handle
544-
except etcd3.Etcd3Exception as e:
545-
logger.error(f"Error watching key {key_prefix}: {e}")
546-
return None
545+
if key_prefix in self._watch_handles:
546+
return self._watch_handles[key_prefix]
547+
watch_handle = Etcd3WatchEventQueue(key_prefix=key_prefix)
548+
watch_id = self.client.add_watch_prefix_callback(
549+
key_prefix, watch_handle.add_events_from_resp)
550+
watch_handle.set_cancel_event(
551+
lambda: self.client.cancel_watch(watch_id))
552+
self._watch_handles[key_prefix] = watch_handle
553+
return watch_handle
547554

548555
async def unwatch(self, key_prefix: str) -> None:
549556
handle = self._watch_handles.pop(key_prefix)

tensorrt_llm/serve/disagg_auto_scaling.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ async def watch_workers(
117117
workers = []
118118
self._watch_handle = await self._cluster_storage.watch(
119119
self.worker_key_prefix)
120+
121+
assert self._watch_handle is not None, "failed to watch workers"
122+
120123
if get_existing_first:
121124
# There is a tiny gap between getting existing workers and watching the key,
122125
# which may cause us to miss some workers registered in between.

0 commit comments

Comments
 (0)