diff --git a/changes/5289.fix.md b/changes/5289.fix.md new file mode 100644 index 00000000000..edf7553a991 --- /dev/null +++ b/changes/5289.fix.md @@ -0,0 +1 @@ +Make AsyncEtcd an explicit async context manager to leverage etcd-client-py update (0.5.1) that ensures graceful tokio runtime shutdown diff --git a/python.lock b/python.lock index 7656530e3a7..3cdafa41dc7 100644 --- a/python.lock +++ b/python.lock @@ -46,7 +46,7 @@ // "coloredlogs~=15.0", // "cryptography>=44.0.2", // "dataclasses-json~=0.5.7", -// "etcd-client-py~=0.4.1", +// "etcd-client-py~=0.5.1", // "faker~=24.7.1", // "graphene-federation~=3.2.0", // "graphene~=3.3.0", @@ -103,7 +103,7 @@ // "tabulate~=0.8.9", // "temporenc~=0.1.0", // "tenacity>=9.0", -// "testcontainers[minio,postgres,redis]~=4.8.1", +// "testcontainers[minio,postgres,redis]~=4.13.3", // "textual~=0.79.1", // "tomli-w~=1.2.0", // "tomli~=2.0.1", @@ -2063,29 +2063,37 @@ "artifacts": [ { "algorithm": "sha256", - "hash": "e10e972639606206892aef94953fb023e4bb1a4fa68143396e49f4d0da31a13a", - "url": "https://files.pythonhosted.org/packages/46/32/6b0a21ddd58c6da6dd1c2fedd9c23fa717cc0ca204a6ca5329c8123374b9/etcd_client_py-0.4.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" + "hash": "63d245dcecdf845aa76b4da4f73d00919f07eb35788e683ee4f809e188c53375", + "url": "https://files.pythonhosted.org/packages/3f/6c/e6b2cf73127561e9c9bae8a41c017be346deeb54d0f8128109c2581beae6/etcd_client_py-0.5.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, { "algorithm": "sha256", - "hash": "26845fbf2bbdee9283df60f58938e8f8ea111e3cce4bb9737154ccd526a01813", - "url": "https://files.pythonhosted.org/packages/25/3c/ec281bc41637d163471f491dc04367799cf6e75e86526ff486b8a997d44c/etcd_client_py-0.4.1-cp313-cp313-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl" + "hash": "c8bc311f225147e3260c6416d4c66146dd291e5af5c6b8120d0b410cce89a8d1", + "url": "https://files.pythonhosted.org/packages/a2/d1/6aad26b2f0581db64650f640c0d7968bbe25cc6a2fc35f7d66d4c0347f6b/etcd_client_py-0.5.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, { "algorithm": "sha256", - "hash": "3909bd4857274e2ffa37aabc6d4a7e49dc2f903ef657dbf141555887b707129e", - "url": "https://files.pythonhosted.org/packages/c7/75/4034fbbdcb9b670e7add878a208484b9aceafea7dc55c4964fd5e0eb67a1/etcd_client_py-0.4.1.tar.gz" + "hash": "dee52d29cd2d441ecebbc84b80c3c08370b0b5859dc5b165bf9133bd209e2565", + "url": "https://files.pythonhosted.org/packages/ab/92/2f21e778b884729d412613713d33eda98c644fbfb03bca4bb0ded94298ec/etcd_client_py-0.5.1-cp313-cp313-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl" }, { "algorithm": "sha256", - "hash": "cdbf59cc758692595ae2f4a592bf75108ddb5d61fa7a5440ba089bbbe766dbe9", - "url": "https://files.pythonhosted.org/packages/fa/91/c67993ae03b54db1f4f308a7cabdf765371a2cbdc2b7e94fa3ea28065c37/etcd_client_py-0.4.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" + "hash": "7809b1adfa58c0e84cbaac166c7e76ce5987558426acaa8651075dd2315a9212", + "url": "https://files.pythonhosted.org/packages/bb/d2/c79d349cd668a24a3647b17db2396c59393ba6b368279abbe1169ba39bc7/etcd_client_py-0.5.1.tar.gz" } ], "project_name": "etcd-client-py", - "requires_dists": [], + "requires_dists": [ + "maturin<2.0,>=1.11; extra == \"dev\"", + "mypy>=1.19.1; extra == \"dev\"", + "pytest-asyncio<2,>=1.3.0; extra == \"test\"", + "pytest<10,>=9.0.2; extra == \"test\"", + "ruff>=0.14.10; extra == \"dev\"", + "testcontainers<5,>=4.13.3; extra == \"test\"", + "trafaret<3,>=2.1; extra == \"test\"" + ], "requires_python": ">=3.10", - "version": "0.4.1" + "version": "0.5.1" }, { "artifacts": [ @@ -6435,13 +6443,13 @@ "artifacts": [ { "algorithm": "sha256", - "hash": "9e19af077cd96e1957c13ee466f1f32905bc6c5bc1bc98643eb18be1a989bfb0", - "url": "https://files.pythonhosted.org/packages/80/77/5ac0dff2903a033d83d971fd85957356abdb66a327f3589df2b3d1a586b4/testcontainers-4.8.2-py3-none-any.whl" + "hash": "063278c4805ffa6dd85e56648a9da3036939e6c0ac1001e851c9276b19b05970", + "url": "https://files.pythonhosted.org/packages/73/27/c2f24b19dafa197c514abe70eda69bc031c5152c6b1f1e5b20099e2ceedd/testcontainers-4.13.3-py3-none-any.whl" }, { "algorithm": "sha256", - "hash": "dd4a6a2ea09e3c3ecd39e180b6548105929d0bb78d665ce9919cb3f8c98f9853", - "url": "https://files.pythonhosted.org/packages/1f/72/c58d84f5704c6caadd9f803a3adad5ab54ac65328c02d13295f40860cf33/testcontainers-4.8.2.tar.gz" + "hash": "9d82a7052c9a53c58b69e1dc31da8e7a715e8b3ec1c4df5027561b47e2efe646", + "url": "https://files.pythonhosted.org/packages/fc/b3/c272537f3ea2f312555efeb86398cc382cd07b740d5f3c730918c36e64e1/testcontainers-4.13.3.tar.gz" } ], "project_name": "testcontainers", @@ -6451,29 +6459,30 @@ "bcrypt; extra == \"registry\"", "boto3; extra == \"aws\" or extra == \"localstack\"", "cassandra-driver==3.29.1; extra == \"scylla\"", - "chromadb-client; extra == \"chroma\"", - "clickhouse-driver; extra == \"clickhouse\"", + "chromadb-client<2.0.0,>=1.0.0; extra == \"chroma\"", "cryptography; extra == \"mailpit\" or extra == \"sftp\"", "docker", "google-cloud-datastore>=2; extra == \"google\"", "google-cloud-pubsub>=2; extra == \"google\"", "httpx; extra == \"aws\" or extra == \"generic\" or extra == \"test-module-import\"", - "ibm_db_sa; extra == \"db2\"", + "ibm_db_sa; (platform_machine != \"aarch64\" and platform_machine != \"arm64\") and extra == \"db2\"", "influxdb-client; extra == \"influxdb\"", "influxdb; extra == \"influxdb\"", "kubernetes; extra == \"k3s\"", "minio; extra == \"minio\"", "nats-py; extra == \"nats\"", "neo4j; extra == \"neo4j\"", - "opensearch-py; extra == \"opensearch\"", - "oracledb; extra == \"oracle\" or extra == \"oracle-free\"", + "openfga-sdk; python_version >= \"3.10\" and extra == \"openfga\"", + "opensearch-py; python_version < \"4.0\" and extra == \"opensearch\"", + "oracledb>=3.4.1; extra == \"oracle\" or extra == \"oracle-free\"", "pika; extra == \"rabbitmq\"", "pymongo; extra == \"mongodb\"", - "pymssql; extra == \"mssql\"", + "pymssql>=2.3.9; (platform_machine != \"arm64\" or python_version >= \"3.10\") and extra == \"mssql\"", "pymysql[rsa]; extra == \"mysql\"", "python-arango<8.0,>=7.8; extra == \"arangodb\"", + "python-dotenv", "python-keycloak; extra == \"keycloak\"", - "pyyaml; extra == \"k3s\"", + "pyyaml>=6.0.3; extra == \"k3s\"", "qdrant-client; extra == \"qdrant\"", "redis; extra == \"generic\" or extra == \"redis\"", "selenium; extra == \"selenium\"", @@ -6484,8 +6493,8 @@ "weaviate-client<5.0.0,>=4.5.4; extra == \"weaviate\"", "wrapt" ], - "requires_python": "<4.0,>=3.9", - "version": "4.8.2" + "requires_python": ">=3.9.2", + "version": "4.13.3" }, { "artifacts": [ @@ -7717,7 +7726,7 @@ "coloredlogs~=15.0", "cryptography>=44.0.2", "dataclasses-json~=0.5.7", - "etcd-client-py~=0.4.1", + "etcd-client-py~=0.5.1", "faker~=24.7.1", "graphene-federation~=3.2.0", "graphene~=3.3.0", @@ -7774,7 +7783,7 @@ "tabulate~=0.8.9", "temporenc~=0.1.0", "tenacity>=9.0", - "testcontainers[minio,postgres,redis]~=4.8.1", + "testcontainers[minio,postgres,redis]~=4.13.3", "textual~=0.79.1", "tomli-w~=1.2.0", "tomli~=2.0.1", diff --git a/requirements.txt b/requirements.txt index 8db242b68f0..2b34939cb94 100644 --- a/requirements.txt +++ b/requirements.txt @@ -106,7 +106,7 @@ pytest>=8.3.3 pytest-aiohttp~=1.1.0 pytest-dependency>=0.6.0 pytest-mock>=3.14.0 -testcontainers[postgres,redis,minio]~=4.8.1 +testcontainers[postgres,redis,minio]~=4.13.3 # type stubs types-six @@ -129,4 +129,4 @@ types-tqdm backend.ai-krunner-alpine==5.4.0 backend.ai-krunner-static-gnu==4.4.0 -etcd-client-py~=0.4.1 +etcd-client-py~=0.5.1 diff --git a/scripts/app-proxy/migrate-health-check-configuration.py b/scripts/app-proxy/migrate-health-check-configuration.py index 33a41432912..1205c6b8b04 100644 --- a/scripts/app-proxy/migrate-health-check-configuration.py +++ b/scripts/app-proxy/migrate-health-check-configuration.py @@ -92,8 +92,8 @@ async def update_appproxy_endpoint_entity( async def main(get_bootstrap_config_coro: Coroutine[None, None, BootstrapConfig]) -> None: config: BootstrapConfig = await get_bootstrap_config_coro - etcd = AsyncEtcd.initialize(config.etcd.to_dataclass()) - raw_volumes_config = await etcd.get_prefix("volumes") + async with AsyncEtcd.create_from_config(config.etcd.to_dataclass()) as etcd: + raw_volumes_config = await etcd.get_prefix("volumes") storage_manager = StorageSessionManager(VolumesConfig(**raw_volumes_config)) db_username = config.db.user diff --git a/src/ai/backend/agent/dependencies/bootstrap/etcd.py b/src/ai/backend/agent/dependencies/bootstrap/etcd.py index 01a5e9ed3cd..38973da96d6 100644 --- a/src/ai/backend/agent/dependencies/bootstrap/etcd.py +++ b/src/ai/backend/agent/dependencies/bootstrap/etcd.py @@ -52,17 +52,13 @@ async def provide(self, setup_input: AgentUnifiedConfig) -> AsyncIterator[AsyncE # Convert config to dataclass format and initialize etcd etcd_config_data = setup_input.etcd.to_dataclass() - etcd = AsyncEtcd( + async with AsyncEtcd( [addr.to_legacy() for addr in etcd_config_data.addrs], setup_input.etcd.namespace, scope_prefix_map, credentials=etcd_credentials, - ) - - try: + ) as etcd: yield etcd - finally: - await etcd.close() def gen_health_checkers(self, resource: AsyncEtcd) -> ServiceHealthChecker: """ diff --git a/src/ai/backend/agent/server.py b/src/ai/backend/agent/server.py index b987dec374f..39fac6a2b89 100644 --- a/src/ai/backend/agent/server.py +++ b/src/ai/backend/agent/server.py @@ -1402,16 +1402,13 @@ async def etcd_ctx(local_config: AgentUnifiedConfig) -> AsyncGenerator[AsyncEtcd ConfigScopes.NODE: f"nodes/agents/{local_config.agent.defaulted_id}", } etcd_config_data = local_config.etcd.to_dataclass() - etcd = AsyncEtcd( + async with AsyncEtcd( [addr.to_legacy() for addr in etcd_config_data.addrs], local_config.etcd.namespace, scope_prefix_map, credentials=etcd_credentials, - ) - try: + ) as etcd: yield etcd - finally: - await etcd.close() async def prepare_krunner_volumes(local_config: AgentUnifiedConfig) -> None: diff --git a/src/ai/backend/agent/watcher/__init__.py b/src/ai/backend/agent/watcher/__init__.py index 11c5fe1fae0..e7918683f63 100644 --- a/src/ai/backend/agent/watcher/__init__.py +++ b/src/ai/backend/agent/watcher/__init__.py @@ -291,52 +291,52 @@ async def watcher_server( scope_prefix_map = { ConfigScopes.GLOBAL: "", } - etcd = AsyncEtcd( + async with AsyncEtcd( app["config"]["etcd"]["addr"], app["config"]["etcd"]["namespace"], scope_prefix_map=scope_prefix_map, credentials=etcd_credentials, - ) - app["config_server"] = etcd - - token = await etcd.get("config/watcher/token") - if token is None: - token = "insecure" - log.debug("watcher authentication token: {}", token) - app["token"] = token - - app.middlewares.append(auth_middleware) - app.on_shutdown.append(shutdown_app) - app.on_startup.append(init_app) - app.on_response_prepare.append(prepare_hook) - ssl_ctx = None - if app["config"]["watcher"]["ssl-enabled"]: - ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - ssl_ctx.load_cert_chain( - str(app["config"]["watcher"]["ssl-cert"]), - str(app["config"]["watcher"]["ssl-privkey"]), + ) as etcd: + app["config_server"] = etcd + + token = await etcd.get("config/watcher/token") + if token is None: + token = "insecure" + log.debug("watcher authentication token: {}", token) + app["token"] = token + + app.middlewares.append(auth_middleware) + app.on_shutdown.append(shutdown_app) + app.on_startup.append(init_app) + app.on_response_prepare.append(prepare_hook) + ssl_ctx = None + if app["config"]["watcher"]["ssl-enabled"]: + ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_ctx.load_cert_chain( + str(app["config"]["watcher"]["ssl-cert"]), + str(app["config"]["watcher"]["ssl-privkey"]), + ) + runner = web.AppRunner(app) + await runner.setup() + watcher_addr = app["config"]["watcher"]["service-addr"] + site = web.TCPSite( + runner, + str(watcher_addr.host), + watcher_addr.port, + backlog=5, + reuse_port=True, + ssl_context=ssl_ctx, ) - runner = web.AppRunner(app) - await runner.setup() - watcher_addr = app["config"]["watcher"]["service-addr"] - site = web.TCPSite( - runner, - str(watcher_addr.host), - watcher_addr.port, - backlog=5, - reuse_port=True, - ssl_context=ssl_ctx, - ) - await site.start() - log.info("started at {}", watcher_addr) - try: - stop_sig = yield - finally: - log.info("shutting down...") - if stop_sig == signal.SIGALRM and shutdown_enabled: - log.warning("shutting down the agent node!") - subprocess.run(["shutdown", "-h", "now"]) - await runner.cleanup() + await site.start() + log.info("started at {}", watcher_addr) + try: + stop_sig = yield + finally: + log.info("shutting down...") + if stop_sig == signal.SIGALRM and shutdown_enabled: + log.warning("shutting down the agent node!") + subprocess.run(["shutdown", "-h", "now"]) + await runner.cleanup() @click.command() diff --git a/src/ai/backend/cli/__main__.py b/src/ai/backend/cli/__main__.py index 1177436a6ed..6ef2ef66b48 100644 --- a/src/ai/backend/cli/__main__.py +++ b/src/ai/backend/cli/__main__.py @@ -1,5 +1,4 @@ import shutil -import time from .loader import load_entry_points @@ -8,8 +7,4 @@ if __name__ == "__main__": # Execute right away if the module is directly called from CLI. - try: - main(max_content_width=shutil.get_terminal_size().columns - 2) - finally: - # Workaround for tokio/pyo3-async-runtimes shutdown race (BA-1976) - time.sleep(0.1) + main(max_content_width=shutil.get_terminal_size().columns - 2) diff --git a/src/ai/backend/common/config.py b/src/ai/backend/common/config.py index aa42bafd91e..798196cd081 100644 --- a/src/ai/backend/common/config.py +++ b/src/ai/backend/common/config.py @@ -403,8 +403,8 @@ def read_from_file( async def read_from_etcd( etcd_config: Mapping[str, Any], scope_prefix_map: Mapping[ConfigScopes, str] ) -> Optional[dict[str, Any]]: - etcd = AsyncEtcd(etcd_config["addr"], etcd_config["namespace"], scope_prefix_map) - raw_value = await etcd.get("daemon/config") + async with AsyncEtcd(etcd_config["addr"], etcd_config["namespace"], scope_prefix_map) as etcd: + raw_value = await etcd.get("daemon/config") if raw_value is None: return None config: dict[str, Any] diff --git a/src/ai/backend/common/etcd.py b/src/ai/backend/common/etcd.py index 60f93f9e78a..41801bf6e70 100644 --- a/src/ai/backend/common/etcd.py +++ b/src/ai/backend/common/etcd.py @@ -23,6 +23,7 @@ MutableMapping, Sequence, ) +from types import TracebackType from typing import ( Optional, ParamSpec, @@ -293,7 +294,7 @@ def __init__( ) @classmethod - def initialize(cls, etcd_config: EtcdConfigData) -> Self: + def create_from_config(cls, etcd_config: EtcdConfigData) -> Self: etcd_addrs = [addr.to_legacy() for addr in etcd_config.addrs] namespace = etcd_config.namespace etcd_user = etcd_config.user @@ -315,8 +316,23 @@ def initialize(cls, etcd_config: EtcdConfigData) -> Self: return cls(etcd_addrs, namespace, scope_prefix_map, credentials=credentials) - async def close(self): - pass # for backward compatibility + async def open(self) -> None: + await self.etcd.__aenter__() + + async def close(self) -> None: + await self.etcd.__aexit__(None, None, None) + + async def __aenter__(self) -> Self: + await self.etcd.__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return await self.etcd.__aexit__(exc_type, exc_val, exc_tb) async def ping(self) -> None: """ diff --git a/src/ai/backend/install/context.py b/src/ai/backend/install/context.py index 17e5b6fa579..af755f95d28 100644 --- a/src/ai/backend/install/context.py +++ b/src/ai/backend/install/context.py @@ -244,16 +244,13 @@ async def etcd_ctx(self) -> AsyncIterator[AsyncEtcd]: "user": halfstack.etcd_user, "password": halfstack.etcd_password, } - etcd = AsyncEtcd( + async with AsyncEtcd( [addr.face for addr in self.install_info.halfstack_config.etcd_addr], "local", scope_prefix_map, credentials=creds, - ) - try: + ) as etcd: yield etcd - finally: - await etcd.close() async def etcd_put_json(self, key: str, value: Any) -> None: async with self.etcd_ctx() as etcd: diff --git a/src/ai/backend/manager/cli/context.py b/src/ai/backend/manager/cli/context.py index 5323e67f860..24d1bd3363c 100644 --- a/src/ai/backend/manager/cli/context.py +++ b/src/ai/backend/manager/cli/context.py @@ -91,16 +91,13 @@ async def etcd_ctx(cli_ctx: CLIContext) -> AsyncIterator[AsyncEtcd]: ConfigScopes.GLOBAL: "", # TODO: provide a way to specify other scope prefixes } - etcd = AsyncEtcd( + async with AsyncEtcd( [addr.to_legacy() for addr in etcd_config_data.addrs], etcd_config_data.namespace, scope_prefix_map, credentials=creds, - ) - try: + ) as etcd: yield etcd - finally: - await etcd.close() @contextlib.asynccontextmanager @@ -113,15 +110,11 @@ async def config_ctx(cli_ctx: CLIContext) -> AsyncIterator[ManagerUnifiedConfig] bootstrap_config = await cli_ctx.get_bootstrap_config() etcd_config_data = bootstrap_config.etcd.to_dataclass() - etcd = AsyncEtcd.initialize(etcd_config_data) - etcd_loader = LegacyEtcdLoader(etcd) - redis_config = await etcd_loader.load() - unified_config = ManagerUnifiedConfig(**redis_config) - - try: - yield unified_config - finally: - await etcd_loader.close() + async with AsyncEtcd.create_from_config(etcd_config_data) as etcd: + etcd_loader = LegacyEtcdLoader(etcd) + redis_config = await etcd_loader.load() + unified_config = ManagerUnifiedConfig(**redis_config) + yield unified_config @contextlib.asynccontextmanager @@ -145,11 +138,11 @@ async def redis_ctx(cli_ctx: CLIContext) -> AsyncIterator[RedisConnectionSet]: bootstrap_config = await cli_ctx.get_bootstrap_config() etcd_config_data = bootstrap_config.etcd.to_dataclass() - etcd = AsyncEtcd.initialize(etcd_config_data) - loader = LegacyEtcdLoader(etcd, config_prefix="config/redis") - raw_redis_config = await loader.load() - redis_config = RedisConfig(**raw_redis_config) - redis_profile_target = redis_config.to_redis_profile_target() + async with AsyncEtcd.create_from_config(etcd_config_data) as etcd: + loader = LegacyEtcdLoader(etcd, config_prefix="config/redis") + raw_redis_config = await loader.load() + redis_config = RedisConfig(**raw_redis_config) + redis_profile_target = redis_config.to_redis_profile_target() valkey_live_client = await ValkeyLiveClient.create( redis_profile_target.profile_target(RedisRole.LIVE).to_valkey_target(), diff --git a/src/ai/backend/manager/config/loader/legacy_etcd_loader.py b/src/ai/backend/manager/config/loader/legacy_etcd_loader.py index 7b0f66e2b56..8f30a5106ce 100644 --- a/src/ai/backend/manager/config/loader/legacy_etcd_loader.py +++ b/src/ai/backend/manager/config/loader/legacy_etcd_loader.py @@ -22,6 +22,13 @@ class LegacyEtcdLoader(AbstractConfigLoader): + """ + A configuration loader from an AsyncEtcd instance. + + The responsibility to keep the etcd client's lifecycle longer than the loader + is on the user of this class. + """ + _etcd: AsyncEtcd _config_prefix: str = "config" @@ -35,9 +42,6 @@ def __init__(self, etcd: AsyncEtcd, config_prefix: Optional[str] = None) -> None async def load(self) -> Mapping[str, Any]: return await self._etcd.get_prefix(self._config_prefix) - async def close(self) -> None: - await self._etcd.close() - def __hash__(self) -> int: # When used as a key in dicts, we don't care our contents. # Just treat it like an opaque object. diff --git a/src/ai/backend/manager/dependencies/bootstrap/etcd.py b/src/ai/backend/manager/dependencies/bootstrap/etcd.py index 2c521183a96..30888275ceb 100644 --- a/src/ai/backend/manager/dependencies/bootstrap/etcd.py +++ b/src/ai/backend/manager/dependencies/bootstrap/etcd.py @@ -28,11 +28,8 @@ async def provide(self, setup_input: BootstrapConfig) -> AsyncIterator[AsyncEtcd Yields: Initialized AsyncEtcd client """ - etcd = AsyncEtcd.initialize(setup_input.etcd.to_dataclass()) - try: + async with AsyncEtcd.create_from_config(setup_input.etcd.to_dataclass()) as etcd: yield etcd - finally: - await etcd.close() def gen_health_checkers(self, resource: AsyncEtcd) -> ServiceHealthChecker: """ diff --git a/src/ai/backend/manager/server.py b/src/ai/backend/manager/server.py index 4cf0f6b9848..0215e01db27 100644 --- a/src/ai/backend/manager/server.py +++ b/src/ai/backend/manager/server.py @@ -453,11 +453,9 @@ async def exception_middleware( @asynccontextmanager async def etcd_ctx(root_ctx: RootContext, etcd_config: EtcdConfigData) -> AsyncIterator[None]: - root_ctx.etcd = AsyncEtcd.initialize(etcd_config) - try: + async with AsyncEtcd.create_from_config(etcd_config) as etcd: + root_ctx.etcd = etcd yield - finally: - await root_ctx.etcd.close() @asynccontextmanager diff --git a/src/ai/backend/storage/dependencies/infrastructure/etcd.py b/src/ai/backend/storage/dependencies/infrastructure/etcd.py index 5e2155d3835..8c55b5e39c3 100644 --- a/src/ai/backend/storage/dependencies/infrastructure/etcd.py +++ b/src/ai/backend/storage/dependencies/infrastructure/etcd.py @@ -21,11 +21,8 @@ def stage_name(self) -> str: @asynccontextmanager async def provide(self, setup_input: StorageProxyUnifiedConfig) -> AsyncIterator[AsyncEtcd]: """Create and provide an etcd client.""" - etcd = make_etcd(setup_input) - try: + async with make_etcd(setup_input) as etcd: yield etcd - finally: - await etcd.close() def gen_health_checkers(self, resource: AsyncEtcd) -> ServiceHealthChecker: """ diff --git a/src/ai/backend/storage/server.py b/src/ai/backend/storage/server.py index 55036df6711..4445b8f1b55 100644 --- a/src/ai/backend/storage/server.py +++ b/src/ai/backend/storage/server.py @@ -171,11 +171,8 @@ async def aiomonitor_ctx( @asynccontextmanager async def etcd_ctx(local_config: StorageProxyUnifiedConfig) -> AsyncGenerator[AsyncEtcd]: - etcd = make_etcd(local_config) - try: + async with make_etcd(local_config) as etcd: yield etcd - finally: - await etcd.close() @asynccontextmanager diff --git a/tests/component/manager/api/test_config.py b/tests/component/manager/api/test_config.py index 72f71681fb4..4c43f2227e1 100644 --- a/tests/component/manager/api/test_config.py +++ b/tests/component/manager/api/test_config.py @@ -14,15 +14,15 @@ async def test_register_myself(bootstrap_config, mocker): mocked_get_instance_id = AsyncMock(return_value=instance_id) mocker.patch.object(loader_mod, "get_instance_id", mocked_get_instance_id) - etcd = AsyncEtcd.initialize(bootstrap_config.etcd.to_dataclass()) - etcd_loader = LegacyEtcdLoader(etcd) - - await etcd_loader.register_myself() - assert mocked_get_instance_id.await_count == 1 - data = await etcd.get_prefix(f"nodes/manager/{instance_id}") - assert data[""] == "up" - - await etcd_loader.deregister_myself() - assert mocked_get_instance_id.await_count == 2 - data = await etcd.get_prefix(f"nodes/manager/{instance_id}") - assert len(data) == 0 + async with AsyncEtcd.create_from_config(bootstrap_config.etcd.to_dataclass()) as etcd: + etcd_loader = LegacyEtcdLoader(etcd) + + await etcd_loader.register_myself() + assert mocked_get_instance_id.await_count == 1 + data = await etcd.get_prefix(f"nodes/manager/{instance_id}") + assert data[""] == "up" + + await etcd_loader.deregister_myself() + assert mocked_get_instance_id.await_count == 2 + data = await etcd.get_prefix(f"nodes/manager/{instance_id}") + assert len(data) == 0 diff --git a/tests/component/manager/conftest.py b/tests/component/manager/conftest.py index 94b3d21a0c0..76151679661 100644 --- a/tests/component/manager/conftest.py +++ b/tests/component/manager/conftest.py @@ -412,13 +412,13 @@ async def unified_config( app, bootstrap_config: BootstrapConfig, etcd_fixture ) -> AsyncIterator[ManagerUnifiedConfig]: root_ctx: RootContext = app["_root.context"] - etcd = AsyncEtcd.initialize(bootstrap_config.etcd.to_dataclass()) - root_ctx.etcd = etcd - etcd_loader = LegacyEtcdLoader(root_ctx.etcd) - raw_config = await etcd_loader.load() - merged_config = {**bootstrap_config.model_dump(), **raw_config} - unified_config = ManagerUnifiedConfig(**merged_config) - yield unified_config + async with AsyncEtcd.create_from_config(bootstrap_config.etcd.to_dataclass()) as etcd: + root_ctx.etcd = etcd + etcd_loader = LegacyEtcdLoader(root_ctx.etcd) + raw_config = await etcd_loader.load() + merged_config = {**bootstrap_config.model_dump(), **raw_config} + unified_config = ManagerUnifiedConfig(**merged_config) + yield unified_config @pytest.fixture(scope="session") diff --git a/tests/conftest.py b/tests/conftest.py index 680f7345020..255ac566c83 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,7 +45,7 @@ async def etcd( Shared etcd fixture for all tests. Creates a real AsyncEtcd client with proper scope prefixing. """ - etcd = AsyncEtcd( + async with AsyncEtcd( addrs=[etcd_container[1].to_legacy()], namespace=test_ns, scope_prefix_map={ @@ -53,15 +53,11 @@ async def etcd( ConfigScopes.SGROUP: "sgroup/testing", ConfigScopes.NODE: "node/i-test", }, - ) - try: + ) as etcd: await etcd.delete_prefix("", scope=ConfigScopes.GLOBAL) await etcd.delete_prefix("", scope=ConfigScopes.SGROUP) await etcd.delete_prefix("", scope=ConfigScopes.NODE) yield etcd - finally: await etcd.delete_prefix("", scope=ConfigScopes.GLOBAL) await etcd.delete_prefix("", scope=ConfigScopes.SGROUP) await etcd.delete_prefix("", scope=ConfigScopes.NODE) - await etcd.close() - del etcd diff --git a/tests/unit/common/conftest.py b/tests/unit/common/conftest.py index 7f2ce13bf58..ccb3e2799fa 100644 --- a/tests/unit/common/conftest.py +++ b/tests/unit/common/conftest.py @@ -183,19 +183,16 @@ async def test_valkey_stream_mq(redis_container, test_node_id) -> AsyncIterator[ @pytest.fixture async def gateway_etcd(etcd_container, test_ns) -> AsyncIterator[AsyncEtcd]: # noqa: F811 - etcd = AsyncEtcd( + async with AsyncEtcd( addrs=[etcd_container[1]], namespace=test_ns, scope_prefix_map={ ConfigScopes.GLOBAL: "", }, - ) - try: + ) as etcd: await etcd.delete_prefix("", scope=ConfigScopes.GLOBAL) yield etcd - finally: await etcd.delete_prefix("", scope=ConfigScopes.GLOBAL) - del etcd @pytest.fixture diff --git a/tests/unit/common/health_checker/checkers/test_etcd.py b/tests/unit/common/health_checker/checkers/test_etcd.py index 8e18e6bf3ef..ece84499373 100644 --- a/tests/unit/common/health_checker/checkers/test_etcd.py +++ b/tests/unit/common/health_checker/checkers/test_etcd.py @@ -28,16 +28,12 @@ async def etcd_client( ConfigScopes.NODE: f"nodes/test/{test_ns}", } - etcd = AsyncEtcd( + async with AsyncEtcd( [HostPortPair(etcd_addr.host, etcd_addr.port)], namespace=test_ns, scope_prefix_map=scope_prefix_map, - ) - - try: + ) as etcd: yield etcd - finally: - await etcd.close() @pytest.mark.asyncio async def test_success(self, etcd_client: AsyncEtcd) -> None: @@ -74,13 +70,11 @@ async def test_connection_error(self) -> None: ConfigScopes.GLOBAL: "", } - etcd = AsyncEtcd( + async with AsyncEtcd( [HostPortPair("localhost", 99999)], namespace="test", scope_prefix_map=scope_prefix_map, - ) - - try: + ) as etcd: checker = EtcdHealthChecker( etcd=etcd, timeout=1.0, @@ -91,8 +85,6 @@ async def test_connection_error(self) -> None: status = result.results[list(result.results.keys())[0]] assert not status.is_healthy assert status.error_message is not None - finally: - await etcd.close() @pytest.mark.asyncio async def test_multiple_checks(self, etcd_client: AsyncEtcd) -> None: diff --git a/tests/unit/common/test_distributed.py b/tests/unit/common/test_distributed.py index bd4f29f41bd..f7cb4ba1adf 100644 --- a/tests/unit/common/test_distributed.py +++ b/tests/unit/common/test_distributed.py @@ -206,10 +206,28 @@ async def _tick(context: Any, source: AgentId, event: NoopAnycastEvent) -> None: await event_dispatcher.start() await asyncio.sleep(0.1) # Allow dispatcher to start - etcd_lock: AbstractDistributedLock + async def _lock_test(dist_lock: AbstractDistributedLock) -> None: + # Common test logic for distributed lock + timer = GlobalTimer( + dist_lock, + event_producer, + lambda: NoopAnycastEvent(timer_ctx.test_case_ns), + timer_ctx.interval, + ) + try: + await timer.join() + while not stop_event.is_set(): + await asyncio.sleep(0) + finally: + await timer.leave() + await event_dispatcher.close() + await event_producer.close() + await redis_mq.close() + await asyncio.sleep(0.2) # Allow cleanup to complete + match etcd_client: case "etcd-client-py": - etcd = AsyncEtcd( + async with AsyncEtcd( addrs=etcd_ctx.addrs, namespace=etcd_ctx.namespace, scope_prefix_map={ @@ -217,25 +235,9 @@ async def _tick(context: Any, source: AgentId, event: NoopAnycastEvent) -> None: ConfigScopes.SGROUP: "sgroup/testing", ConfigScopes.NODE: "node/i-test", }, - ) - etcd_lock = EtcdLock(etcd_ctx.lock_name, etcd, timeout=None, debug=True) - - timer = GlobalTimer( - etcd_lock, - event_producer, - lambda: NoopAnycastEvent(timer_ctx.test_case_ns), - timer_ctx.interval, - ) - try: - await timer.join() - while not stop_event.is_set(): - await asyncio.sleep(0) - finally: - await timer.leave() - await event_dispatcher.close() - await event_producer.close() - await redis_mq.close() - await asyncio.sleep(0.2) # Allow cleanup to complete + ) as etcd: + etcd_lock = EtcdLock(etcd_ctx.lock_name, etcd, timeout=None, debug=True) + await _lock_test(etcd_lock) asyncio.run(_main()) diff --git a/tests/unit/storage-proxy/conftest.py b/tests/unit/storage-proxy/conftest.py index 835e87f883b..0c3f0cbcd65 100644 --- a/tests/unit/storage-proxy/conftest.py +++ b/tests/unit/storage-proxy/conftest.py @@ -40,14 +40,15 @@ def local_volume(vfroot) -> Iterator[Path]: @pytest.fixture -def mock_etcd() -> Iterator[AsyncEtcd]: - yield AsyncEtcd( +async def mock_etcd() -> AsyncIterator[AsyncEtcd]: + async with AsyncEtcd( addrs=[HostPortPair("", 0)], namespace="", scope_prefix_map={ ConfigScopes.GLOBAL: "", }, - ) + ) as etcd: + yield etcd def has_backend(backend_name: str) -> dict[str, Any] | None: diff --git a/tests/unit/storage/dependencies/infrastructure/test_redis.py b/tests/unit/storage/dependencies/infrastructure/test_redis.py index 94733009292..67573f76b71 100644 --- a/tests/unit/storage/dependencies/infrastructure/test_redis.py +++ b/tests/unit/storage/dependencies/infrastructure/test_redis.py @@ -48,8 +48,7 @@ async def etcd_client( redis_container_id, redis_addr = redis_container - etcd = make_etcd(storage_config) - try: + async with make_etcd(storage_config) as etcd: # Store redis config in etcd for RedisProvider await etcd.put_prefix( "config/redis", @@ -61,8 +60,6 @@ async def etcd_client( }, ) yield etcd - finally: - await etcd.close() @pytest.mark.integration @pytest.mark.asyncio