1111from grpc ._cython import cygrpc
1212
1313from pymilvus .client .types import GrantInfo , ResourceGroupConfig
14- from pymilvus .decorators import ignore_unimplemented , retry_on_rpc_failure
14+ from pymilvus .decorators import (
15+ ignore_unimplemented ,
16+ retry_on_rpc_failure ,
17+ retry_on_schema_mismatch ,
18+ )
1519from pymilvus .exceptions import (
1620 AmbiguousIndexName ,
17- DataNotMatchException ,
1821 DescribeCollectionException ,
1922 ErrorCode ,
2023 ExceptionsMessage ,
2932from . import entity_helper , ts_utils , utils
3033from .abstract import AnnSearchRequest , BaseRanker , CollectionSchema , FieldSchema , MutationResult
3134from .async_interceptor import async_header_adder_interceptor
35+ from .cache import GlobalCache
3236from .check import (
3337 check_id_and_data ,
3438 check_pass_param ,
@@ -73,7 +77,6 @@ def __init__(
7377 ) -> None :
7478 self ._async_stub = None
7579 self ._async_channel = channel
76- self .schema_cache : Dict [str , dict ] = {}
7780
7881 addr = kwargs .get ("address" )
7982 self ._address = addr if addr is not None else self .__get_address (uri , host , port )
@@ -137,15 +140,10 @@ def _setup_authorization_interceptor(self, user: str, password: str, token: str)
137140
138141 def _setup_db_name (self , db_name : str ):
139142 if db_name is None :
140- new_db = None
143+ self . _db_name = ""
141144 else :
142145 check_pass_param (db_name = db_name )
143- new_db = db_name
144-
145- if getattr (self , "_db_name" , None ) != new_db :
146- self .schema_cache .clear ()
147-
148- self ._db_name = new_db
146+ self ._db_name = db_name
149147
150148 def _setup_grpc_channel (self , ** kwargs ):
151149 if self ._async_channel is None :
@@ -270,6 +268,8 @@ async def drop_collection(
270268 request , timeout = timeout , metadata = _api_level_md (** kwargs )
271269 )
272270 check_status (response )
271+ # Invalidate global schema cache
272+ self ._invalidate_schema (collection_name )
273273
274274 @retry_on_rpc_failure ()
275275 async def load_collection (
@@ -520,46 +520,47 @@ async def rename_collection(
520520
521521 async def _get_info (self , collection_name : str , timeout : Optional [float ] = None , ** kwargs ):
522522 schema = kwargs .get ("schema" )
523- schema , schema_timestamp = await self ._get_schema_from_cache_or_remote (
524- collection_name , schema = schema , timeout = timeout , ** kwargs
523+ schema , schema_timestamp = await self ._get_schema (
524+ collection_name , timeout = timeout , ** kwargs
525525 )
526526 fields_info = schema .get ("fields" )
527527 struct_fields_info = schema .get ("struct_array_fields" , [])
528528 enable_dynamic = schema .get ("enable_dynamic_field" , False )
529529
530530 return fields_info , struct_fields_info , enable_dynamic , schema_timestamp
531531
532- async def update_schema (
533- self , collection_name : str , timeout : Optional [float ] = None , ** kwargs
534- ) -> dict :
535- self .schema_cache .pop (collection_name , None )
536- schema = await self .describe_collection (collection_name , timeout = timeout , ** kwargs )
537- schema_timestamp = schema .get ("update_timestamp" , 0 )
538- self .schema_cache [collection_name ] = {
539- "schema" : schema ,
540- "schema_timestamp" : schema_timestamp ,
541- }
542- return schema
543-
544- async def _get_schema_from_cache_or_remote (
532+ async def _get_schema (
545533 self ,
546534 collection_name : str ,
547- schema : Optional [dict ] = None ,
548535 timeout : Optional [float ] = None ,
549536 ** kwargs ,
550537 ) -> Tuple [dict , int ]:
551- if collection_name in self .schema_cache :
552- cached = self .schema_cache [collection_name ]
553- return cached ["schema" ], cached ["schema_timestamp" ]
554-
555- if not isinstance (schema , dict ):
556- schema = await self .describe_collection (collection_name , timeout = timeout , ** kwargs )
557- schema_timestamp = schema .get ("update_timestamp" , 0 )
558- self .schema_cache [collection_name ] = {
559- "schema" : schema ,
560- "schema_timestamp" : schema_timestamp ,
561- }
562- return schema , schema_timestamp
538+ """
539+ Get collection schema, using cache when available.
540+
541+ Returns:
542+ Tuple of (schema_dict, schema_timestamp)
543+ """
544+ cache = GlobalCache .schema
545+ endpoint = self .server_address
546+ db_name = self ._db_name or ""
547+
548+ cached = cache .get (endpoint , db_name , collection_name )
549+ if cached is not None :
550+ return cached , cached .get ("update_timestamp" , 0 )
551+
552+ # Fetch from server and cache
553+ schema = await self .describe_collection (collection_name , timeout = timeout , ** kwargs )
554+ cache .set (endpoint , db_name , collection_name , schema )
555+ return schema , schema .get ("update_timestamp" , 0 )
556+
557+ def _invalidate_schema (self , collection_name : str ) -> None :
558+ """Invalidate cached schema for a collection."""
559+ GlobalCache .schema .invalidate (self .server_address , self ._db_name or "" , collection_name )
560+
561+ def _invalidate_db_schemas (self , db_name : str ) -> None :
562+ """Invalidate all cached schemas for a database."""
563+ GlobalCache .schema .invalidate_db (self .server_address , db_name )
563564
564565 @retry_on_rpc_failure ()
565566 async def release_collection (
@@ -574,6 +575,7 @@ async def release_collection(
574575 check_status (response )
575576
576577 @retry_on_rpc_failure ()
578+ @retry_on_schema_mismatch ()
577579 async def insert_rows (
578580 self ,
579581 collection_name : str ,
@@ -584,26 +586,12 @@ async def insert_rows(
584586 ** kwargs ,
585587 ):
586588 await self .ensure_channel_ready ()
587- try :
588- request = await self ._prepare_row_insert_request (
589- collection_name , entities , partition_name , schema , timeout , ** kwargs
590- )
591- except DataNotMatchException :
592- schema = await self .update_schema (collection_name , timeout , ** kwargs )
593- request = await self ._prepare_row_insert_request (
594- collection_name , entities , partition_name , schema , timeout , ** kwargs
595- )
589+ request = await self ._prepare_row_insert_request (
590+ collection_name , entities , partition_name , schema , timeout , ** kwargs
591+ )
596592 resp = await self ._async_stub .Insert (
597593 request = request , timeout = timeout , metadata = _api_level_md (** kwargs )
598594 )
599- if resp .status .error_code == common_pb2 .SchemaMismatch :
600- schema = await self .update_schema (collection_name , timeout , ** kwargs )
601- request = await self ._prepare_row_insert_request (
602- collection_name , entities , partition_name , schema , timeout , ** kwargs
603- )
604- resp = await self ._async_stub .Insert (
605- request = request , timeout = timeout , metadata = _api_level_md (** kwargs )
606- )
607595 check_status (resp .status )
608596 ts_utils .update_collection_ts (collection_name , resp .timestamp )
609597 return MutationResult (resp )
@@ -620,7 +608,7 @@ async def _prepare_row_insert_request(
620608 if isinstance (entity_rows , dict ):
621609 entity_rows = [entity_rows ]
622610
623- schema , schema_timestamp = await self ._get_schema_from_cache_or_remote (
611+ schema , schema_timestamp = await self ._get_schema (
624612 collection_name , schema = schema , timeout = timeout , ** kwargs
625613 )
626614 fields_info = schema .get ("fields" )
@@ -688,7 +676,7 @@ async def _prepare_batch_upsert_request(
688676 partial_update = kwargs .get ("partial_update" , False )
689677
690678 schema = kwargs .get ("schema" )
691- schema , _ = await self ._get_schema_from_cache_or_remote (
679+ schema , _ = await self ._get_schema (
692680 collection_name , schema = schema , timeout = timeout , ** kwargs
693681 )
694682
@@ -757,6 +745,7 @@ async def _prepare_row_upsert_request(
757745 )
758746
759747 @retry_on_rpc_failure ()
748+ @retry_on_schema_mismatch ()
760749 async def upsert_rows (
761750 self ,
762751 collection_name : str ,
@@ -768,26 +757,12 @@ async def upsert_rows(
768757 await self .ensure_channel_ready ()
769758 if isinstance (entities , dict ):
770759 entities = [entities ]
771- try :
772- request = await self ._prepare_row_upsert_request (
773- collection_name , entities , partition_name , timeout , ** kwargs
774- )
775- except DataNotMatchException :
776- schema = await self .update_schema (collection_name , timeout , ** kwargs )
777- request = await self ._prepare_row_upsert_request (
778- collection_name , entities , partition_name , timeout , schema = schema , ** kwargs
779- )
760+ request = await self ._prepare_row_upsert_request (
761+ collection_name , entities , partition_name , timeout , ** kwargs
762+ )
780763 response = await self ._async_stub .Upsert (
781764 request , timeout = timeout , metadata = _api_level_md (** kwargs )
782765 )
783- if response .status .error_code == common_pb2 .SchemaMismatch :
784- schema = await self .update_schema (collection_name , timeout , ** kwargs )
785- request = await self ._prepare_row_upsert_request (
786- collection_name , entities , partition_name , timeout , schema = schema , ** kwargs
787- )
788- response = await self ._async_stub .Upsert (
789- request , timeout = timeout , metadata = _api_level_md (** kwargs )
790- )
791766 check_status (response .status )
792767 m = MutationResult (response )
793768 ts_utils .update_collection_ts (collection_name , m .timestamp )
0 commit comments