Skip to content

Commit 8fac038

Browse files
pymilvus-bot, jac0626, silas.jiang
authored
[Backport 2.6] Refactor: Introduce thread-safe GlobalSchemaCache for gRPC handlers (#3188) (#3193)
Backport of #3188 to `2.6`. Signed-off-by: silas.jiang <silas.jiang@zilliz.com> Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: jac <jacllovey@qq.com> Co-authored-by: silas.jiang <silas.jiang@zilliz.com>
1 parent d8a452b commit 8fac038

File tree

10 files changed

+524
-188
lines changed

10 files changed

+524
-188
lines changed

pymilvus/client/async_grpc_handler.py

Lines changed: 49 additions & 74 deletions
Original file line number · Diff line number · Diff line change
@@ -11,10 +11,13 @@
1111
from grpc._cython import cygrpc
1212

1313
from pymilvus.client.types import GrantInfo, ResourceGroupConfig
14-
from pymilvus.decorators import ignore_unimplemented, retry_on_rpc_failure
14+
from pymilvus.decorators import (
15+
ignore_unimplemented,
16+
retry_on_rpc_failure,
17+
retry_on_schema_mismatch,
18+
)
1519
from pymilvus.exceptions import (
1620
AmbiguousIndexName,
17-
DataNotMatchException,
1821
DescribeCollectionException,
1922
ErrorCode,
2023
ExceptionsMessage,
@@ -29,6 +32,7 @@
2932
from . import entity_helper, ts_utils, utils
3033
from .abstract import AnnSearchRequest, BaseRanker, CollectionSchema, FieldSchema, MutationResult
3134
from .async_interceptor import async_header_adder_interceptor
35+
from .cache import GlobalCache
3236
from .check import (
3337
check_id_and_data,
3438
check_pass_param,
@@ -73,7 +77,6 @@ def __init__(
7377
) -> None:
7478
self._async_stub = None
7579
self._async_channel = channel
76-
self.schema_cache: Dict[str, dict] = {}
7780

7881
addr = kwargs.get("address")
7982
self._address = addr if addr is not None else self.__get_address(uri, host, port)
@@ -137,15 +140,10 @@ def _setup_authorization_interceptor(self, user: str, password: str, token: str)
137140

138141
def _setup_db_name(self, db_name: str):
139142
if db_name is None:
140-
new_db = None
143+
self._db_name = ""
141144
else:
142145
check_pass_param(db_name=db_name)
143-
new_db = db_name
144-
145-
if getattr(self, "_db_name", None) != new_db:
146-
self.schema_cache.clear()
147-
148-
self._db_name = new_db
146+
self._db_name = db_name
149147

150148
def _setup_grpc_channel(self, **kwargs):
151149
if self._async_channel is None:
@@ -270,6 +268,8 @@ async def drop_collection(
270268
request, timeout=timeout, metadata=_api_level_md(**kwargs)
271269
)
272270
check_status(response)
271+
# Invalidate global schema cache
272+
self._invalidate_schema(collection_name)
273273

274274
@retry_on_rpc_failure()
275275
async def load_collection(
@@ -520,46 +520,47 @@ async def rename_collection(
520520

521521
async def _get_info(self, collection_name: str, timeout: Optional[float] = None, **kwargs):
522522
schema = kwargs.get("schema")
523-
schema, schema_timestamp = await self._get_schema_from_cache_or_remote(
524-
collection_name, schema=schema, timeout=timeout, **kwargs
523+
schema, schema_timestamp = await self._get_schema(
524+
collection_name, timeout=timeout, **kwargs
525525
)
526526
fields_info = schema.get("fields")
527527
struct_fields_info = schema.get("struct_array_fields", [])
528528
enable_dynamic = schema.get("enable_dynamic_field", False)
529529

530530
return fields_info, struct_fields_info, enable_dynamic, schema_timestamp
531531

532-
async def update_schema(
533-
self, collection_name: str, timeout: Optional[float] = None, **kwargs
534-
) -> dict:
535-
self.schema_cache.pop(collection_name, None)
536-
schema = await self.describe_collection(collection_name, timeout=timeout, **kwargs)
537-
schema_timestamp = schema.get("update_timestamp", 0)
538-
self.schema_cache[collection_name] = {
539-
"schema": schema,
540-
"schema_timestamp": schema_timestamp,
541-
}
542-
return schema
543-
544-
async def _get_schema_from_cache_or_remote(
532+
async def _get_schema(
545533
self,
546534
collection_name: str,
547-
schema: Optional[dict] = None,
548535
timeout: Optional[float] = None,
549536
**kwargs,
550537
) -> Tuple[dict, int]:
551-
if collection_name in self.schema_cache:
552-
cached = self.schema_cache[collection_name]
553-
return cached["schema"], cached["schema_timestamp"]
554-
555-
if not isinstance(schema, dict):
556-
schema = await self.describe_collection(collection_name, timeout=timeout, **kwargs)
557-
schema_timestamp = schema.get("update_timestamp", 0)
558-
self.schema_cache[collection_name] = {
559-
"schema": schema,
560-
"schema_timestamp": schema_timestamp,
561-
}
562-
return schema, schema_timestamp
538+
"""
539+
Get collection schema, using cache when available.
540+
541+
Returns:
542+
Tuple of (schema_dict, schema_timestamp)
543+
"""
544+
cache = GlobalCache.schema
545+
endpoint = self.server_address
546+
db_name = self._db_name or ""
547+
548+
cached = cache.get(endpoint, db_name, collection_name)
549+
if cached is not None:
550+
return cached, cached.get("update_timestamp", 0)
551+
552+
# Fetch from server and cache
553+
schema = await self.describe_collection(collection_name, timeout=timeout, **kwargs)
554+
cache.set(endpoint, db_name, collection_name, schema)
555+
return schema, schema.get("update_timestamp", 0)
556+
557+
def _invalidate_schema(self, collection_name: str) -> None:
558+
"""Invalidate cached schema for a collection."""
559+
GlobalCache.schema.invalidate(self.server_address, self._db_name or "", collection_name)
560+
561+
def _invalidate_db_schemas(self, db_name: str) -> None:
562+
"""Invalidate all cached schemas for a database."""
563+
GlobalCache.schema.invalidate_db(self.server_address, db_name)
563564

564565
@retry_on_rpc_failure()
565566
async def release_collection(
@@ -574,6 +575,7 @@ async def release_collection(
574575
check_status(response)
575576

576577
@retry_on_rpc_failure()
578+
@retry_on_schema_mismatch()
577579
async def insert_rows(
578580
self,
579581
collection_name: str,
@@ -584,26 +586,12 @@ async def insert_rows(
584586
**kwargs,
585587
):
586588
await self.ensure_channel_ready()
587-
try:
588-
request = await self._prepare_row_insert_request(
589-
collection_name, entities, partition_name, schema, timeout, **kwargs
590-
)
591-
except DataNotMatchException:
592-
schema = await self.update_schema(collection_name, timeout, **kwargs)
593-
request = await self._prepare_row_insert_request(
594-
collection_name, entities, partition_name, schema, timeout, **kwargs
595-
)
589+
request = await self._prepare_row_insert_request(
590+
collection_name, entities, partition_name, schema, timeout, **kwargs
591+
)
596592
resp = await self._async_stub.Insert(
597593
request=request, timeout=timeout, metadata=_api_level_md(**kwargs)
598594
)
599-
if resp.status.error_code == common_pb2.SchemaMismatch:
600-
schema = await self.update_schema(collection_name, timeout, **kwargs)
601-
request = await self._prepare_row_insert_request(
602-
collection_name, entities, partition_name, schema, timeout, **kwargs
603-
)
604-
resp = await self._async_stub.Insert(
605-
request=request, timeout=timeout, metadata=_api_level_md(**kwargs)
606-
)
607595
check_status(resp.status)
608596
ts_utils.update_collection_ts(collection_name, resp.timestamp)
609597
return MutationResult(resp)
@@ -620,7 +608,7 @@ async def _prepare_row_insert_request(
620608
if isinstance(entity_rows, dict):
621609
entity_rows = [entity_rows]
622610

623-
schema, schema_timestamp = await self._get_schema_from_cache_or_remote(
611+
schema, schema_timestamp = await self._get_schema(
624612
collection_name, schema=schema, timeout=timeout, **kwargs
625613
)
626614
fields_info = schema.get("fields")
@@ -688,7 +676,7 @@ async def _prepare_batch_upsert_request(
688676
partial_update = kwargs.get("partial_update", False)
689677

690678
schema = kwargs.get("schema")
691-
schema, _ = await self._get_schema_from_cache_or_remote(
679+
schema, _ = await self._get_schema(
692680
collection_name, schema=schema, timeout=timeout, **kwargs
693681
)
694682

@@ -757,6 +745,7 @@ async def _prepare_row_upsert_request(
757745
)
758746

759747
@retry_on_rpc_failure()
748+
@retry_on_schema_mismatch()
760749
async def upsert_rows(
761750
self,
762751
collection_name: str,
@@ -768,26 +757,12 @@ async def upsert_rows(
768757
await self.ensure_channel_ready()
769758
if isinstance(entities, dict):
770759
entities = [entities]
771-
try:
772-
request = await self._prepare_row_upsert_request(
773-
collection_name, entities, partition_name, timeout, **kwargs
774-
)
775-
except DataNotMatchException:
776-
schema = await self.update_schema(collection_name, timeout, **kwargs)
777-
request = await self._prepare_row_upsert_request(
778-
collection_name, entities, partition_name, timeout, schema=schema, **kwargs
779-
)
760+
request = await self._prepare_row_upsert_request(
761+
collection_name, entities, partition_name, timeout, **kwargs
762+
)
780763
response = await self._async_stub.Upsert(
781764
request, timeout=timeout, metadata=_api_level_md(**kwargs)
782765
)
783-
if response.status.error_code == common_pb2.SchemaMismatch:
784-
schema = await self.update_schema(collection_name, timeout, **kwargs)
785-
request = await self._prepare_row_upsert_request(
786-
collection_name, entities, partition_name, timeout, schema=schema, **kwargs
787-
)
788-
response = await self._async_stub.Upsert(
789-
request, timeout=timeout, metadata=_api_level_md(**kwargs)
790-
)
791766
check_status(response.status)
792767
m = MutationResult(response)
793768
ts_utils.update_collection_ts(collection_name, m.timestamp)

pymilvus/client/cache.py

Lines changed: 101 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,101 @@
1+
import logging
2+
import threading
3+
from typing import Any, ClassVar, Optional, Tuple
4+
5+
from cachetools import LRUCache
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
class CacheRegion:
11+
"""
12+
Thread-safe LRU cache base class.
13+
14+
Subclasses should define specific key types and value types.
15+
"""
16+
17+
DEFAULT_CAPACITY = 4096
18+
19+
def __init__(self, capacity: int = DEFAULT_CAPACITY):
20+
self._cache: LRUCache = LRUCache(maxsize=capacity)
21+
self._lock = threading.Lock()
22+
23+
def get(self, key: Any) -> Optional[Any]:
24+
"""Get value from cache. Returns None if not found."""
25+
with self._lock:
26+
return self._cache.get(key)
27+
28+
def set(self, key: Any, value: Any) -> None:
29+
"""Set value in cache. Evicts LRU entry if over capacity."""
30+
with self._lock:
31+
self._cache[key] = value
32+
33+
def invalidate(self, key: Any) -> None:
34+
"""Remove a specific key from cache."""
35+
with self._lock:
36+
self._cache.pop(key, None)
37+
38+
def clear(self) -> None:
39+
"""Clear all entries from cache."""
40+
with self._lock:
41+
self._cache.clear()
42+
43+
def __len__(self) -> int:
44+
"""Return number of cached entries."""
45+
with self._lock:
46+
return len(self._cache)
47+
48+
49+
class SchemaCache(CacheRegion):
50+
"""
51+
Schema-specific cache with tuple-based keys.
52+
53+
Key: (endpoint, db_name, collection_name)
54+
Value: schema dict
55+
"""
56+
57+
def get(self, endpoint: str, db_name: str, collection_name: str) -> Optional[dict]:
58+
"""Get schema from cache."""
59+
key = self._make_key(endpoint, db_name, collection_name)
60+
return super().get(key)
61+
62+
def set(self, endpoint: str, db_name: str, collection_name: str, schema: dict) -> None:
63+
"""Set schema in cache."""
64+
key = self._make_key(endpoint, db_name, collection_name)
65+
super().set(key, schema)
66+
67+
def invalidate(self, endpoint: str, db_name: str, collection_name: str) -> None:
68+
"""Invalidate schema for a specific collection."""
69+
key = self._make_key(endpoint, db_name, collection_name)
70+
super().invalidate(key)
71+
72+
def invalidate_db(self, endpoint: str, db_name: str) -> None:
73+
"""Invalidate all schemas for a database."""
74+
prefix = (endpoint, db_name or "default")
75+
with self._lock:
76+
keys_to_remove = [k for k in self._cache if k[:2] == prefix]
77+
for key in keys_to_remove:
78+
self._cache.pop(key, None)
79+
80+
@staticmethod
81+
def _make_key(endpoint: str, db_name: str, collection_name: str) -> Tuple[str, str, str]:
82+
"""Create tuple key from components."""
83+
db = db_name if db_name else "default"
84+
return (endpoint, db, collection_name)
85+
86+
87+
class GlobalCache:
88+
"""
89+
Global access point for all cache instances.
90+
91+
Usage:
92+
GlobalCache.schema.get(endpoint, db_name, collection_name)
93+
GlobalCache.schema.set(endpoint, db_name, collection_name, schema)
94+
"""
95+
96+
schema: ClassVar[SchemaCache] = SchemaCache()
97+
98+
@classmethod
99+
def _reset_for_testing(cls) -> None:
100+
"""Reset cache for testing. Creates new instances."""
101+
cls.schema = SchemaCache()

0 commit comments

Comments
 (0)