Skip to content

Commit 27d1440

Browse files
jac0626 and silas.jiang authored
Refactor: Introduce thread-safe GlobalSchemaCache for gRPC handlers (#3188)
- Global Singleton: Replaced instance-level caches with a thread-safe GlobalSchemaCache - LRU Eviction: Implemented LRU logic with a 4096 capacity. - Efficiency: Reduced redundant describe_collection calls across client instances. see #3186 --------- Signed-off-by: silas.jiang <silas.jiang@zilliz.com> Co-authored-by: silas.jiang <silas.jiang@zilliz.com>
1 parent b1bf65d commit 27d1440

File tree

10 files changed

+524
-188
lines changed

10 files changed

+524
-188
lines changed

pymilvus/client/async_grpc_handler.py

Lines changed: 49 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,13 @@
1111
from grpc._cython import cygrpc
1212

1313
from pymilvus.client.types import GrantInfo, ResourceGroupConfig
14-
from pymilvus.decorators import ignore_unimplemented, retry_on_rpc_failure
14+
from pymilvus.decorators import (
15+
ignore_unimplemented,
16+
retry_on_rpc_failure,
17+
retry_on_schema_mismatch,
18+
)
1519
from pymilvus.exceptions import (
1620
AmbiguousIndexName,
17-
DataNotMatchException,
1821
DescribeCollectionException,
1922
ErrorCode,
2023
ExceptionsMessage,
@@ -29,6 +32,7 @@
2932
from . import entity_helper, ts_utils, utils
3033
from .abstract import AnnSearchRequest, BaseRanker, CollectionSchema, FieldSchema, MutationResult
3134
from .async_interceptor import async_header_adder_interceptor
35+
from .cache import GlobalCache
3236
from .check import (
3337
check_id_and_data,
3438
check_pass_param,
@@ -73,7 +77,6 @@ def __init__(
7377
) -> None:
7478
self._async_stub = None
7579
self._async_channel = channel
76-
self.schema_cache: Dict[str, dict] = {}
7780

7881
addr = kwargs.get("address")
7982
self._address = addr if addr is not None else self.__get_address(uri, host, port)
@@ -137,15 +140,10 @@ def _setup_authorization_interceptor(self, user: str, password: str, token: str)
137140

138141
def _setup_db_name(self, db_name: str):
    """Record the database name for this handler; an empty string means the default DB."""
    if db_name is None:
        self._db_name = ""
        return
    # Validate the caller-supplied name before adopting it.
    check_pass_param(db_name=db_name)
    self._db_name = db_name
149147

150148
def _setup_grpc_channel(self, **kwargs):
151149
if self._async_channel is None:
@@ -270,6 +268,8 @@ async def drop_collection(
270268
request, timeout=timeout, metadata=_api_level_md(**kwargs)
271269
)
272270
check_status(response)
271+
# Invalidate global schema cache
272+
self._invalidate_schema(collection_name)
273273

274274
@retry_on_rpc_failure()
275275
async def truncate_collection(
@@ -532,46 +532,47 @@ async def rename_collection(
532532

533533
async def _get_info(self, collection_name: str, timeout: Optional[float] = None, **kwargs):
    """
    Fetch the collection schema (via the global cache) and unpack the pieces callers need.

    Args:
        collection_name: Name of the collection to describe.
        timeout: RPC timeout for a remote describe_collection call on a cache miss.

    Returns:
        Tuple of (fields_info, struct_fields_info, enable_dynamic, schema_timestamp).
    """
    # A caller-supplied schema is no longer honored here; pop it so the dead
    # value does not leak through **kwargs into describe_collection on a
    # cache miss. (The previous `schema = kwargs.get("schema")` assignment
    # was dead code — it was immediately overwritten below.)
    kwargs.pop("schema", None)
    schema, schema_timestamp = await self._get_schema(
        collection_name, timeout=timeout, **kwargs
    )
    fields_info = schema.get("fields")
    struct_fields_info = schema.get("struct_array_fields", [])
    enable_dynamic = schema.get("enable_dynamic_field", False)

    return fields_info, struct_fields_info, enable_dynamic, schema_timestamp
543543

544-
async def update_schema(
545-
self, collection_name: str, timeout: Optional[float] = None, **kwargs
546-
) -> dict:
547-
self.schema_cache.pop(collection_name, None)
548-
schema = await self.describe_collection(collection_name, timeout=timeout, **kwargs)
549-
schema_timestamp = schema.get("update_timestamp", 0)
550-
self.schema_cache[collection_name] = {
551-
"schema": schema,
552-
"schema_timestamp": schema_timestamp,
553-
}
554-
return schema
555-
556-
async def _get_schema_from_cache_or_remote(
544+
async def _get_schema(
    self,
    collection_name: str,
    schema: Optional[dict] = None,
    timeout: Optional[float] = None,
    **kwargs,
) -> Tuple[dict, int]:
    """
    Get collection schema, using the global cache when available.

    Args:
        collection_name: Name of the collection.
        schema: Optional pre-fetched schema dict. When provided it is used
            directly and the global cache is bypassed (not populated with
            caller-supplied data).
        timeout: RPC timeout for a remote describe_collection call.

    Returns:
        Tuple of (schema_dict, schema_timestamp)
    """
    # Callers such as _prepare_row_insert_request still pass schema=...;
    # without this parameter the value would fall into **kwargs and be
    # forwarded to describe_collection on a cache miss.
    if isinstance(schema, dict):
        return schema, schema.get("update_timestamp", 0)

    cache = GlobalCache.schema
    endpoint = self.server_address
    db_name = self._db_name or ""

    cached = cache.get(endpoint, db_name, collection_name)
    if cached is not None:
        # NOTE(review): the cached dict is shared process-wide; callers must
        # treat it as read-only or they will poison the cache.
        return cached, cached.get("update_timestamp", 0)

    # Cache miss: fetch from server and populate the global cache.
    schema = await self.describe_collection(collection_name, timeout=timeout, **kwargs)
    cache.set(endpoint, db_name, collection_name, schema)
    return schema, schema.get("update_timestamp", 0)
568+
569+
def _invalidate_schema(self, collection_name: str) -> None:
    """Drop any cached schema entry for *collection_name* on this connection."""
    db = self._db_name or ""
    schema_cache = GlobalCache.schema
    schema_cache.invalidate(self.server_address, db, collection_name)
572+
573+
def _invalidate_db_schemas(self, db_name: str) -> None:
    """Evict every cached schema under *db_name* for this server endpoint."""
    schema_cache = GlobalCache.schema
    schema_cache.invalidate_db(self.server_address, db_name)
575576

576577
@retry_on_rpc_failure()
577578
async def release_collection(
@@ -586,6 +587,7 @@ async def release_collection(
586587
check_status(response)
587588

588589
@retry_on_rpc_failure()
590+
@retry_on_schema_mismatch()
589591
async def insert_rows(
590592
self,
591593
collection_name: str,
@@ -596,26 +598,12 @@ async def insert_rows(
596598
**kwargs,
597599
):
598600
await self.ensure_channel_ready()
599-
try:
600-
request = await self._prepare_row_insert_request(
601-
collection_name, entities, partition_name, schema, timeout, **kwargs
602-
)
603-
except DataNotMatchException:
604-
schema = await self.update_schema(collection_name, timeout, **kwargs)
605-
request = await self._prepare_row_insert_request(
606-
collection_name, entities, partition_name, schema, timeout, **kwargs
607-
)
601+
request = await self._prepare_row_insert_request(
602+
collection_name, entities, partition_name, schema, timeout, **kwargs
603+
)
608604
resp = await self._async_stub.Insert(
609605
request=request, timeout=timeout, metadata=_api_level_md(**kwargs)
610606
)
611-
if resp.status.error_code == common_pb2.SchemaMismatch:
612-
schema = await self.update_schema(collection_name, timeout, **kwargs)
613-
request = await self._prepare_row_insert_request(
614-
collection_name, entities, partition_name, schema, timeout, **kwargs
615-
)
616-
resp = await self._async_stub.Insert(
617-
request=request, timeout=timeout, metadata=_api_level_md(**kwargs)
618-
)
619607
check_status(resp.status)
620608
ts_utils.update_collection_ts(collection_name, resp.timestamp)
621609
return MutationResult(resp)
@@ -632,7 +620,7 @@ async def _prepare_row_insert_request(
632620
if isinstance(entity_rows, dict):
633621
entity_rows = [entity_rows]
634622

635-
schema, schema_timestamp = await self._get_schema_from_cache_or_remote(
623+
schema, schema_timestamp = await self._get_schema(
636624
collection_name, schema=schema, timeout=timeout, **kwargs
637625
)
638626
fields_info = schema.get("fields")
@@ -700,7 +688,7 @@ async def _prepare_batch_upsert_request(
700688
partial_update = kwargs.get("partial_update", False)
701689

702690
schema = kwargs.get("schema")
703-
schema, _ = await self._get_schema_from_cache_or_remote(
691+
schema, _ = await self._get_schema(
704692
collection_name, schema=schema, timeout=timeout, **kwargs
705693
)
706694

@@ -770,6 +758,7 @@ async def _prepare_row_upsert_request(
770758
)
771759

772760
@retry_on_rpc_failure()
761+
@retry_on_schema_mismatch()
773762
async def upsert_rows(
774763
self,
775764
collection_name: str,
@@ -781,26 +770,12 @@ async def upsert_rows(
781770
await self.ensure_channel_ready()
782771
if isinstance(entities, dict):
783772
entities = [entities]
784-
try:
785-
request = await self._prepare_row_upsert_request(
786-
collection_name, entities, partition_name, timeout, **kwargs
787-
)
788-
except DataNotMatchException:
789-
schema = await self.update_schema(collection_name, timeout, **kwargs)
790-
request = await self._prepare_row_upsert_request(
791-
collection_name, entities, partition_name, timeout, schema=schema, **kwargs
792-
)
773+
request = await self._prepare_row_upsert_request(
774+
collection_name, entities, partition_name, timeout, **kwargs
775+
)
793776
response = await self._async_stub.Upsert(
794777
request, timeout=timeout, metadata=_api_level_md(**kwargs)
795778
)
796-
if response.status.error_code == common_pb2.SchemaMismatch:
797-
schema = await self.update_schema(collection_name, timeout, **kwargs)
798-
request = await self._prepare_row_upsert_request(
799-
collection_name, entities, partition_name, timeout, schema=schema, **kwargs
800-
)
801-
response = await self._async_stub.Upsert(
802-
request, timeout=timeout, metadata=_api_level_md(**kwargs)
803-
)
804779
check_status(response.status)
805780
m = MutationResult(response)
806781
ts_utils.update_collection_ts(collection_name, m.timestamp)

pymilvus/client/cache.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import logging
import threading
from collections import OrderedDict
from typing import Any, ClassVar, Optional, Tuple

from cachetools import LRUCache
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
class CacheRegion:
    """
    Thread-safe LRU cache base class.

    Implemented on collections.OrderedDict (stdlib) instead of a third-party
    LRU class: most-recently-used keys live at the right end, eviction pops
    from the left. Semantics are unchanged — `get` refreshes recency and
    `set` evicts the least-recently-used entries once over capacity.

    Subclasses should define specific key types and value types.
    """

    DEFAULT_CAPACITY = 4096

    def __init__(self, capacity: int = DEFAULT_CAPACITY):
        self._capacity = capacity
        self._cache: "OrderedDict[Any, Any]" = OrderedDict()
        self._lock = threading.Lock()

    def get(self, key: Any) -> Optional[Any]:
        """Get value from cache and refresh its recency. Returns None if not found."""
        with self._lock:
            try:
                # move_to_end raises KeyError on a miss — one lookup, no race.
                self._cache.move_to_end(key)
            except KeyError:
                return None
            return self._cache[key]

    def set(self, key: Any, value: Any) -> None:
        """Set value in cache. Evicts LRU entries if over capacity."""
        with self._lock:
            self._cache[key] = value
            self._cache.move_to_end(key)
            while len(self._cache) > self._capacity:
                # last=False pops the oldest (least-recently-used) entry.
                self._cache.popitem(last=False)

    def invalidate(self, key: Any) -> None:
        """Remove a specific key from cache."""
        with self._lock:
            self._cache.pop(key, None)

    def clear(self) -> None:
        """Clear all entries from cache."""
        with self._lock:
            self._cache.clear()

    def __len__(self) -> int:
        """Return number of cached entries."""
        with self._lock:
            return len(self._cache)
47+
48+
49+
class SchemaCache(CacheRegion):
    """
    Schema-specific cache with tuple-based keys.

    Key: (endpoint, db_name, collection_name)
    Value: schema dict
    """

    def get(self, endpoint: str, db_name: str, collection_name: str) -> Optional[dict]:
        """Return the cached schema for the collection, or None when absent."""
        return super().get(self._make_key(endpoint, db_name, collection_name))

    def set(self, endpoint: str, db_name: str, collection_name: str, schema: dict) -> None:
        """Store *schema* under the composite (endpoint, db, collection) key."""
        super().set(self._make_key(endpoint, db_name, collection_name), schema)

    def invalidate(self, endpoint: str, db_name: str, collection_name: str) -> None:
        """Drop the cached schema for one collection."""
        super().invalidate(self._make_key(endpoint, db_name, collection_name))

    def invalidate_db(self, endpoint: str, db_name: str) -> None:
        """Drop every cached schema belonging to one database on one endpoint."""
        prefix = (endpoint, db_name or "default")
        with self._lock:
            # Snapshot matching keys first; popping while iterating is unsafe.
            stale = [key for key in self._cache if key[:2] == prefix]
            for key in stale:
                self._cache.pop(key, None)

    @staticmethod
    def _make_key(endpoint: str, db_name: str, collection_name: str) -> Tuple[str, str, str]:
        """Build the tuple key, normalizing an empty db name to "default"."""
        return (endpoint, db_name or "default", collection_name)
85+
86+
87+
class GlobalCache:
    """
    Global access point for all cache instances.

    Usage:
        GlobalCache.schema.get(endpoint, db_name, collection_name)
        GlobalCache.schema.set(endpoint, db_name, collection_name, schema)
    """

    # Process-wide schema cache shared by every client/handler instance.
    schema: ClassVar[SchemaCache] = SchemaCache()

    @classmethod
    def _reset_for_testing(cls) -> None:
        """Reset cache for testing. Creates new instances."""
        # NOTE: rebinds the class attribute — code holding a direct reference
        # to the old SchemaCache instance will not observe the reset.
        cls.schema = SchemaCache()

0 commit comments

Comments (0)