1111from grpc ._cython import cygrpc
1212
1313from pymilvus .client .types import GrantInfo , ResourceGroupConfig
14- from pymilvus .decorators import ignore_unimplemented , retry_on_rpc_failure
14+ from pymilvus .decorators import (
15+ ignore_unimplemented ,
16+ retry_on_rpc_failure ,
17+ retry_on_schema_mismatch ,
18+ )
1519from pymilvus .exceptions import (
1620 AmbiguousIndexName ,
17- DataNotMatchException ,
1821 DescribeCollectionException ,
1922 ErrorCode ,
2023 ExceptionsMessage ,
2932from . import entity_helper , ts_utils , utils
3033from .abstract import AnnSearchRequest , BaseRanker , CollectionSchema , FieldSchema , MutationResult
3134from .async_interceptor import async_header_adder_interceptor
35+ from .cache import GlobalCache
3236from .check import (
3337 check_id_and_data ,
3438 check_pass_param ,
@@ -73,7 +77,6 @@ def __init__(
7377 ) -> None :
7478 self ._async_stub = None
7579 self ._async_channel = channel
76- self .schema_cache : Dict [str , dict ] = {}
7780
7881 addr = kwargs .get ("address" )
7982 self ._address = addr if addr is not None else self .__get_address (uri , host , port )
@@ -137,15 +140,10 @@ def _setup_authorization_interceptor(self, user: str, password: str, token: str)
137140
def _setup_db_name(self, db_name: str):
    """Record the database name this handler operates on.

    Args:
        db_name: Database name to use. ``None`` is stored as the empty
            string; any other value is passed to ``check_pass_param``
            before it is stored.
    """
    if db_name is not None:
        # Validation runs before the assignment, so an exception raised by the
        # check leaves the previously stored name in place.
        check_pass_param(db_name=db_name)
        self._db_name = db_name
        return
    self._db_name = ""
149147
150148 def _setup_grpc_channel (self , ** kwargs ):
151149 if self ._async_channel is None :
@@ -270,6 +268,8 @@ async def drop_collection(
270268 request , timeout = timeout , metadata = _api_level_md (** kwargs )
271269 )
272270 check_status (response )
271+ # Invalidate global schema cache
272+ self ._invalidate_schema (collection_name )
273273
274274 @retry_on_rpc_failure ()
275275 async def truncate_collection (
@@ -532,46 +532,47 @@ async def rename_collection(
532532
async def _get_info(self, collection_name: str, timeout: Optional[float] = None, **kwargs):
    """Return the schema parts needed to build requests for a collection.

    Args:
        collection_name: Name of the collection to look up.
        timeout: Optional timeout forwarded to ``_get_schema``.
        **kwargs: Forwarded unchanged to ``_get_schema``.

    Returns:
        Tuple ``(fields_info, struct_fields_info, enable_dynamic, schema_timestamp)``.
        ``fields_info`` is ``None`` when the schema has no ``"fields"`` key,
        ``struct_fields_info`` defaults to ``[]`` and ``enable_dynamic`` to
        ``False`` when their keys are absent.
    """
    # The schema always comes from _get_schema. A "schema" entry in kwargs is
    # not read here; it simply travels along with the rest of kwargs.
    schema, schema_timestamp = await self._get_schema(
        collection_name, timeout=timeout, **kwargs
    )
    return (
        schema.get("fields"),
        schema.get("struct_array_fields", []),
        schema.get("enable_dynamic_field", False),
        schema_timestamp,
    )
543543
544- async def update_schema (
545- self , collection_name : str , timeout : Optional [float ] = None , ** kwargs
546- ) -> dict :
547- self .schema_cache .pop (collection_name , None )
548- schema = await self .describe_collection (collection_name , timeout = timeout , ** kwargs )
549- schema_timestamp = schema .get ("update_timestamp" , 0 )
550- self .schema_cache [collection_name ] = {
551- "schema" : schema ,
552- "schema_timestamp" : schema_timestamp ,
553- }
554- return schema
555-
556- async def _get_schema_from_cache_or_remote (
async def _get_schema(
    self,
    collection_name: str,
    timeout: Optional[float] = None,
    **kwargs,
) -> Tuple[dict, int]:
    """Get a collection schema, using the global schema cache when possible.

    The cache entry is addressed by ``(server_address, db_name, collection_name)``,
    where an unset/empty ``_db_name`` is treated as ``""``. On a miss the schema
    is fetched with ``describe_collection`` and stored under that key.

    Args:
        collection_name: Name of the collection whose schema is needed.
        timeout: Optional timeout forwarded to ``describe_collection`` on a miss.
        **kwargs: Extra options forwarded to ``describe_collection`` on a miss.
            A ``schema`` entry is accepted for call-site compatibility but is
            discarded and never forwarded.

    Returns:
        Tuple of ``(schema_dict, schema_timestamp)``; the timestamp is the
        schema's ``"update_timestamp"`` value, or ``0`` when that key is absent.
    """
    # Callers still pass ``schema=...`` from when it was a named parameter of the
    # previous helper. It is not a describe_collection option, so drop it here
    # instead of leaking it into the RPC helper's kwargs.
    kwargs.pop("schema", None)

    cache = GlobalCache.schema
    endpoint = self.server_address
    db_name = self._db_name or ""

    schema = cache.get(endpoint, db_name, collection_name)
    if schema is None:
        # Cache miss: fetch from the server and remember the result.
        schema = await self.describe_collection(collection_name, timeout=timeout, **kwargs)
        cache.set(endpoint, db_name, collection_name, schema)
    return schema, schema.get("update_timestamp", 0)
568+
def _invalidate_schema(self, collection_name: str) -> None:
    """Drop the cached schema entry for ``collection_name``.

    The entry is addressed by this handler's server address and its current
    database name, with an unset/empty name treated as ``""``.
    """
    schema_cache = GlobalCache.schema
    endpoint = self.server_address
    db_name = self._db_name or ""
    schema_cache.invalidate(endpoint, db_name, collection_name)
572+
def _invalidate_db_schemas(self, db_name: str) -> None:
    """Drop every cached schema entry recorded for ``db_name`` on this server.

    NOTE(review): ``db_name`` is passed through as given, whereas the
    per-collection helpers normalise an unset name to "" — confirm callers
    never pass ``None`` here.
    """
    schema_cache = GlobalCache.schema
    endpoint = self.server_address
    schema_cache.invalidate_db(endpoint, db_name)
575576
576577 @retry_on_rpc_failure ()
577578 async def release_collection (
@@ -586,6 +587,7 @@ async def release_collection(
586587 check_status (response )
587588
588589 @retry_on_rpc_failure ()
590+ @retry_on_schema_mismatch ()
589591 async def insert_rows (
590592 self ,
591593 collection_name : str ,
@@ -596,26 +598,12 @@ async def insert_rows(
596598 ** kwargs ,
597599 ):
598600 await self .ensure_channel_ready ()
599- try :
600- request = await self ._prepare_row_insert_request (
601- collection_name , entities , partition_name , schema , timeout , ** kwargs
602- )
603- except DataNotMatchException :
604- schema = await self .update_schema (collection_name , timeout , ** kwargs )
605- request = await self ._prepare_row_insert_request (
606- collection_name , entities , partition_name , schema , timeout , ** kwargs
607- )
601+ request = await self ._prepare_row_insert_request (
602+ collection_name , entities , partition_name , schema , timeout , ** kwargs
603+ )
608604 resp = await self ._async_stub .Insert (
609605 request = request , timeout = timeout , metadata = _api_level_md (** kwargs )
610606 )
611- if resp .status .error_code == common_pb2 .SchemaMismatch :
612- schema = await self .update_schema (collection_name , timeout , ** kwargs )
613- request = await self ._prepare_row_insert_request (
614- collection_name , entities , partition_name , schema , timeout , ** kwargs
615- )
616- resp = await self ._async_stub .Insert (
617- request = request , timeout = timeout , metadata = _api_level_md (** kwargs )
618- )
619607 check_status (resp .status )
620608 ts_utils .update_collection_ts (collection_name , resp .timestamp )
621609 return MutationResult (resp )
@@ -632,7 +620,7 @@ async def _prepare_row_insert_request(
632620 if isinstance (entity_rows , dict ):
633621 entity_rows = [entity_rows ]
634622
635- schema , schema_timestamp = await self ._get_schema_from_cache_or_remote (
623+ schema , schema_timestamp = await self ._get_schema (
636624 collection_name , schema = schema , timeout = timeout , ** kwargs
637625 )
638626 fields_info = schema .get ("fields" )
@@ -700,7 +688,7 @@ async def _prepare_batch_upsert_request(
700688 partial_update = kwargs .get ("partial_update" , False )
701689
702690 schema = kwargs .get ("schema" )
703- schema , _ = await self ._get_schema_from_cache_or_remote (
691+ schema , _ = await self ._get_schema (
704692 collection_name , schema = schema , timeout = timeout , ** kwargs
705693 )
706694
@@ -770,6 +758,7 @@ async def _prepare_row_upsert_request(
770758 )
771759
772760 @retry_on_rpc_failure ()
761+ @retry_on_schema_mismatch ()
773762 async def upsert_rows (
774763 self ,
775764 collection_name : str ,
@@ -781,26 +770,12 @@ async def upsert_rows(
781770 await self .ensure_channel_ready ()
782771 if isinstance (entities , dict ):
783772 entities = [entities ]
784- try :
785- request = await self ._prepare_row_upsert_request (
786- collection_name , entities , partition_name , timeout , ** kwargs
787- )
788- except DataNotMatchException :
789- schema = await self .update_schema (collection_name , timeout , ** kwargs )
790- request = await self ._prepare_row_upsert_request (
791- collection_name , entities , partition_name , timeout , schema = schema , ** kwargs
792- )
773+ request = await self ._prepare_row_upsert_request (
774+ collection_name , entities , partition_name , timeout , ** kwargs
775+ )
793776 response = await self ._async_stub .Upsert (
794777 request , timeout = timeout , metadata = _api_level_md (** kwargs )
795778 )
796- if response .status .error_code == common_pb2 .SchemaMismatch :
797- schema = await self .update_schema (collection_name , timeout , ** kwargs )
798- request = await self ._prepare_row_upsert_request (
799- collection_name , entities , partition_name , timeout , schema = schema , ** kwargs
800- )
801- response = await self ._async_stub .Upsert (
802- request , timeout = timeout , metadata = _api_level_md (** kwargs )
803- )
804779 check_status (response .status )
805780 m = MutationResult (response )
806781 ts_utils .update_collection_ts (collection_name , m .timestamp )
0 commit comments