11import abc
22import asyncio
3- import logging
43import time
54from dataclasses import dataclass
65from enum import IntEnum
76from functools import wraps
8- from typing import Dict , List , Optional
7+ from typing import Callable , Dict , List , Optional
98
109import aiohttp
10+ import etcd3
1111from fastapi import FastAPI
1212from fastapi .responses import JSONResponse
1313from pydantic import BaseModel
1414
15- logger = logging . getLogger ( 'uvicorn.error' )
15+ from tensorrt_llm . logger import logger
1616
1717
1818class StorageItem (BaseModel ):
@@ -91,7 +91,7 @@ async def delete(self, key: str) -> bool:
9191 async def watch (self , key_prefix : str ) -> WatchEventQueue :
9292 ...
9393
94- # unwatch the key prefix, if the key prefix is not in the watch list, raise an error
94+ # unwatch the key prefix, if the key prefix is not in the watch list, raise a KeyError
9595 async def unwatch (self , key_prefix : str ) -> None :
9696 ...
9797
@@ -106,12 +106,16 @@ async def get_prefix(self,
def create_cluster_storage(cluster_uri, cluster_name, **kwargs):
    """Build the cluster storage backend selected by the URI prefix.

    A URI starting with ``http`` selects ``HttpClusterStorageServer``; one
    starting with ``etcd`` selects ``Etcd3ClusterStorage``. Extra keyword
    arguments are forwarded unchanged to the selected constructor.

    Raises:
        ValueError: if the URI starts with neither prefix.
    """
    if cluster_uri.startswith("http"):
        storage_cls = HttpClusterStorageServer
    elif cluster_uri.startswith("etcd"):
        storage_cls = Etcd3ClusterStorage
    else:
        raise ValueError(f"Invalid cluster storage URI: {cluster_uri}")
    return storage_cls(cluster_uri, cluster_name, **kwargs)
110112
111113
def create_cluster_storage_client(cluster_uri, cluster_name, **kwargs):
    """Build the cluster storage client selected by the URI prefix.

    A URI starting with ``http`` selects ``HttpClusterStorageClient``; one
    starting with ``etcd`` selects ``Etcd3ClusterStorage``. Extra keyword
    arguments are forwarded unchanged to the selected constructor.

    Raises:
        ValueError: if the URI starts with neither prefix.
    """
    if cluster_uri.startswith("http"):
        client_cls = HttpClusterStorageClient
    elif cluster_uri.startswith("etcd"):
        client_cls = Etcd3ClusterStorage
    else:
        raise ValueError(f"Invalid cluster storage URI: {cluster_uri}")
    return client_cls(cluster_uri, cluster_name, **kwargs)
116120
117121
@@ -241,7 +245,7 @@ async def unwatch(self, key_prefix: str) -> None:
241245 if key_prefix in self ._watch_handles :
242246 self ._watch_handles .pop (key_prefix )
243247 else :
244- raise ValueError (
248+ raise KeyError (
245249 f"Key prefix { key_prefix } not in watch list, { self ._watch_handles .keys ()} "
246250 )
247251
@@ -377,3 +381,159 @@ async def watch(self, key_prefix: str) -> WatchEventQueue:
377381 async def unwatch (self , key_prefix : str ) -> None :
378382 raise NotImplementedError (
379383 "Unwatch functionality not implemented for HTTP client" )
384+
385+
class Etcd3WatchEventQueue(WatchEventQueue):
    """Queue of watch events fed by an etcd3 watch callback.

    ``add_event`` is the callback: it converts every event of a watch
    response into a ``WatchEvent`` and places it on ``self.events``.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self,
                 key_prefix: str,
                 cancel_event: Optional[Callable[[], None]] = None):
        """Create an empty queue for ``key_prefix``.

        Args:
            key_prefix: The watched key prefix.
            cancel_event: Optional zero-argument callable that cancels the
                underlying watch; it can also be supplied later through
                ``set_cancel_event``.
        """
        self.key_prefix = key_prefix
        self._cancel_event = cancel_event
        # Consumers read converted WatchEvent objects from this queue.
        self.events = asyncio.Queue()
        # Remember the loop this queue was created on (if any) so add_event
        # can hand items over with call_soon_threadsafe when it runs on a
        # different thread.
        try:
            self._owner_loop = asyncio.get_running_loop()
        except RuntimeError:
            self._owner_loop = None

    def cancel_event(self):
        """Invoke the cancel callback if one has been set."""
        # getattr: __del__ may run on an instance whose __init__ never finished.
        cancel = getattr(self, "_cancel_event", None)
        if cancel:
            cancel()

    def set_cancel_event(self, cancel_event: Callable[[], None]):
        """Install or replace the callable used to cancel the watch."""
        self._cancel_event = cancel_event

    def __del__(self):
        # Cancel the watch when the queue is garbage collected.
        self.cancel_event()

    def _on_owner_loop(self) -> bool:
        """Return True when the caller is running on the captured loop."""
        try:
            return asyncio.get_running_loop() is self._owner_loop
        except RuntimeError:
            return False

    def _enqueue(self, watch_event) -> None:
        """Put ``watch_event`` on the queue without racing the event loop.

        asyncio.Queue is not thread-safe, so when the caller is on another
        thread than the captured (running) loop the put is scheduled on that
        loop via ``call_soon_threadsafe``.
        """
        loop = getattr(self, "_owner_loop", None)
        if loop is not None and loop.is_running() and not self._on_owner_loop():
            loop.call_soon_threadsafe(self.events.put_nowait, watch_event)
            return
        self.events.put_nowait(watch_event)
        if loop is None:
            # No loop was captured at construction: keep the legacy wake-up
            # of whatever loop the queue has bound itself to (private API).
            bound_loop = getattr(self.events, "_loop", None)
            if bound_loop is not None:
                bound_loop._write_to_self()

    def add_event(self, watch_resp):
        """Convert the events of ``watch_resp`` and enqueue them.

        NOTE(review): presumably invoked on etcd3's watcher thread rather
        than on the event loop thread -- confirm against the etcd3 client.

        Any failure (including ``watch_resp`` not having ``events``) is
        logged and the watch is cancelled; nothing is raised to the caller.
        """
        try:
            for event in watch_resp.events:
                # Event type is not in public interface of etcd3
                is_put = "Put" in event.__class__.__name__
                event_type = (WatchEventType.SET
                              if is_put else WatchEventType.DELETE)
                self._enqueue(
                    WatchEvent(
                        storage_item=StorageItem(
                            key=event.key.decode("utf-8"),
                            value=event.value.decode("utf-8")),
                        event_type=event_type,
                    ))
        except Exception as e:
            logger.error(f"Error adding event: {e}")
            self.cancel_event()
422+
423+
class Etcd3ClusterStorage(ClusterStorage):
    """Cluster storage backed by an etcd server through the ``etcd3`` client.

    Keys written with a positive TTL are attached to an etcd lease: one
    lease per key by default, or a single lease shared by every key when
    ``one_single_lease`` is set.

    NOTE(review): the etcd3 client calls in the ``async`` methods below are
    plain (non-awaited) calls, so each one runs to completion before control
    returns to the event loop.
    """

    _URI_SCHEME = "etcd://"

    def __init__(self,
                 cluster_uri: str,
                 cluster_name: str,
                 one_single_lease: bool = False):
        """Connect to the etcd endpoint named by ``cluster_uri``.

        Args:
            cluster_uri: ``etcd://<host>:<port>`` (the scheme is optional).
            cluster_name: Accepted for interface parity; not used here.
            one_single_lease: Share one lease across all keys instead of
                creating one lease per key.

        Raises:
            ValueError: if ``cluster_uri`` has no ``<host>:<port>`` part.
        """
        # Set every attribute before anything can fail so __del__ is safe
        # on a partially constructed instance.
        self._client = None
        self._leases = {}
        self._instance_lease = None
        self._watch_handles = {}
        self._one_single_lease = one_single_lease
        host, port = self._parse_uri(cluster_uri)
        self._client = etcd3.client(host, port)

    @classmethod
    def _parse_uri(cls, cluster_uri: str):
        """Split ``[etcd://]host:port[/]`` into ``(host, int(port))``."""
        address = cluster_uri
        # Strip the scheme only where it belongs: at the start of the URI.
        if address.startswith(cls._URI_SCHEME):
            address = address[len(cls._URI_SCHEME):]
        address = address.rstrip("/")
        host, sep, port = address.rpartition(":")
        if not sep or not host or not port.isdigit():
            raise ValueError(
                f"Invalid etcd cluster URI, expected etcd://<host>:<port>: "
                f"{cluster_uri}")
        return host, int(port)

    def __del__(self):
        # getattr guards: __init__ may have raised before these were set.
        handles = getattr(self, "_watch_handles", None)
        if handles:
            handles.clear()
        client = getattr(self, "_client", None)
        if client is not None:
            client.close()

    def _get_lease(self, key: str, ttl: int = -1) -> Optional[etcd3.Lease]:
        """Return the lease to attach to ``key``, or None when ``ttl <= 0``.

        Leases are created on first use and cached; ``ttl`` only takes
        effect when the lease is created.
        """
        if ttl <= 0:
            return None
        if self._one_single_lease:
            # Create the shared lease lazily; otherwise it would stay None
            # and every TTL would be silently dropped.
            if self._instance_lease is None:
                self._instance_lease = self.client.lease(ttl)
            return self._instance_lease
        if key not in self._leases:
            self._leases[key] = self.client.lease(ttl)
        return self._leases[key]

    @property
    def client(self):
        """The underlying etcd3 client."""
        return self._client

    async def start(self):
        # nothing to do
        ...

    async def stop(self):
        # nothing to do
        ...

    async def set(self,
                  key: str,
                  value: str,
                  overwrite_if_exists: bool = False,
                  ttl: int = -1) -> bool:
        """Store ``value`` under ``key``.

        Returns:
            With ``overwrite_if_exists`` False, the result of the client's
            ``put_if_not_exists``; otherwise True after the put. False when
            the etcd3 client raises ``Etcd3Exception``.
        """
        try:
            lease = self._get_lease(key, ttl)
            if not overwrite_if_exists:
                return self.client.put_if_not_exists(key, value, lease=lease)
            self.client.put(key, value, lease=lease)
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error setting key {key}: {e}")
            return False
        return True

    async def get(self, key: str) -> Optional[str]:
        """Return the decoded value of ``key``, or None if absent or on error."""
        try:
            data, _meta = self.client.get(key)
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error getting key {key}: {e}")
            return None
        # Compare against None so an existing key holding an empty value is
        # returned as "" instead of being reported as missing.
        return data.decode('utf-8') if data is not None else None

    async def delete(self, key: str) -> bool:
        """Delete ``key``; return False only when the etcd3 client raises."""
        try:
            self.client.delete(key)
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error deleting key {key}: {e}")
            return False
        return True

    async def expire(self, key: str, ttl: int) -> bool:
        """Refresh the lease associated with ``key``.

        Raises:
            ValueError: if ``ttl`` is not positive.

        Returns:
            True after the refresh, False when the etcd3 client raises.
        """
        if ttl <= 0:
            raise ValueError(f"TTL must be greater than 0, got {ttl}")
        try:
            lease = self._get_lease(key, ttl)
            # TTL will be ignored since it can only be set when creating a lease
            # Etcd3Client.refresh_lease is a generator in python-etcd3, so
            # merely calling it sends nothing; Lease.refresh() drains it.
            lease.refresh()
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error refreshing lease {key}: {e}")
            return False
        return True

    async def get_prefix(self,
                         key_prefix: str,
                         keys_only: bool = False) -> Dict[str, str]:
        """Return ``{key: value}`` for keys under ``key_prefix``.

        Values are ``""`` when ``keys_only`` is set. Returns ``{}`` when
        the etcd3 client raises.
        """
        try:
            resp = self.client.get_prefix(key_prefix, keys_only=keys_only)
            return {
                metadata.key.decode("utf-8"):
                "" if keys_only else v.decode("utf-8")
                for v, metadata in resp
            }
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error getting keys {key_prefix}: {e}")
            return {}

    async def watch(self, key_prefix: str) -> WatchEventQueue:
        """Start watching ``key_prefix`` and return its event queue.

        A prefix that is already watched returns the existing queue.
        Returns None when the etcd3 client raises.
        """
        try:
            if key_prefix in self._watch_handles:
                return self._watch_handles[key_prefix]
            watch_handle = Etcd3WatchEventQueue(key_prefix=key_prefix)
            watch_id = self.client.add_watch_prefix_callback(
                key_prefix, watch_handle.add_event)
            watch_handle.set_cancel_event(
                lambda: self.client.cancel_watch(watch_id))
            self._watch_handles[key_prefix] = watch_handle
            return watch_handle
        except etcd3.Etcd3Exception as e:
            logger.error(f"Error watching key {key_prefix}: {e}")
            return None

    async def unwatch(self, key_prefix: str) -> None:
        """Stop watching ``key_prefix``.

        Raises:
            KeyError: if ``key_prefix`` is not currently watched.
        """
        try:
            handle = self._watch_handles.pop(key_prefix)
        except KeyError:
            raise KeyError(
                f"Key prefix {key_prefix} not in watch list, {self._watch_handles.keys()}"
            ) from None
        if handle:
            handle.cancel_event()
0 commit comments