Skip to content

Commit f69e354

Browse files
XuanYang-cn and claude authored
enhance: general performance improvements across MilvusClient path (#3279)
## Summary - **Fix duplicate `@retry_on_rpc_failure`** on `search()` in `grpc_handler.py` — was causing up to 75×75=5625 retry attempts instead of 75 - **24 micro-optimizations** across the MilvusClient/AsyncMilvusClient code path (no refactoring, no API changes): - Replace O(n²) bytes concatenation with `b"".join` and `struct.pack` batch (~136x faster at n=1000) - Use `orjson.loads` instead of stdlib `json.loads` (~6.7x faster) - Replace `deepcopy` with shallow `dict()` for search params (~4.2x faster) - Use `frozenset` for O(1) membership checks instead of list O(n) (~3.6x faster) - Use `WhichOneof` instead of cascading `HasField` in `len_of` - Use `time.monotonic` + lazy `traceback.format_exc` in error_handler (~13-17x faster) - Use builtin `isinstance(x, dict/list)` instead of `typing.Dict/List` (~3.3-6.5x faster) - Cache protobuf attribute chains to avoid repeated traversals (~1.3x faster) - Use dict dispatch tables instead of if/elif chains - Move per-call dict/function creation to module-level constants (~4.5x faster) - Use `template.copy()` instead of `dict.fromkeys` per iteration (~5.4x faster) - Skip dynamic field json comprehension when dynamic is disabled (~16x faster) - Consolidate `extend([item])` to `append` (~2.8x faster) - **80 benchmark tests** added in `tests/benchmark/test_perf_improvements.py` covering all optimization patterns - Net **-536 / +1562 lines** across 11 files (10 production + 1 benchmark test file) - All 2126 existing unit tests pass, all 80 benchmark tests pass ## Test plan - [x] `make lint` passes (black + ruff clean) - [x] `make unittest` — 2126 passed, 4 skipped, 1 xfailed - [x] `pytest tests/benchmark/test_perf_improvements.py` — 80 benchmarks pass - [ ] CI checks pass 🤖 Generated with [Claude Code](https://claude.com/claude-code) Signed-off-by: yangxuan <xuan.yang@zilliz.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 921e74b commit f69e354

File tree

11 files changed

+1562
-536
lines changed

11 files changed

+1562
-536
lines changed

pymilvus/client/async_grpc_handler.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import asyncio
22
import base64
3-
import json
43
import socket
54
import time
65
from pathlib import Path
76
from typing import Dict, List, Optional, Tuple, Union
87
from urllib import parse
98

109
import grpc
10+
import orjson
1111
from grpc._cython import cygrpc
1212

1313
from pymilvus.client.call_context import CallContext, _api_level_md
@@ -1414,7 +1414,8 @@ async def query(
14141414
_, dynamic_fields = entity_helper.extract_dynamic_field_from_result(response)
14151415
keys = [field_data.field_name for field_data in response.fields_data]
14161416
filtered_keys = [k for k in keys if k != "$meta"]
1417-
results = [dict.fromkeys(filtered_keys) for _ in range(num_entities)]
1417+
template = dict.fromkeys(filtered_keys)
1418+
results = [template.copy() for _ in range(num_entities)]
14181419
lazy_field_data = []
14191420
for field_data in response.fields_data:
14201421
lazy_extracted = entity_helper.extract_row_data_from_fields_data_v2(field_data, results)
@@ -1620,7 +1621,7 @@ async def describe_index(
16201621
info_dict["field_name"] = response.index_descriptions[0].field_name
16211622
info_dict["index_name"] = response.index_descriptions[0].index_name
16221623
if info_dict.get("params"):
1623-
info_dict["params"] = json.loads(info_dict["params"])
1624+
info_dict["params"] = orjson.loads(info_dict["params"])
16241625
info_dict["total_rows"] = response.index_descriptions[0].total_rows
16251626
info_dict["indexed_rows"] = response.index_descriptions[0].indexed_rows
16261627
info_dict["pending_index_rows"] = response.index_descriptions[0].pending_index_rows

pymilvus/client/entity_helper.py

Lines changed: 168 additions & 237 deletions
Large diffs are not rendered by default.

pymilvus/client/grpc_handler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import base64
2-
import json
32
import logging
43
import socket
54
import threading
@@ -9,6 +8,7 @@
98
from urllib import parse
109

1110
import grpc
11+
import orjson
1212
from grpc._cython import cygrpc
1313

1414
from pymilvus.client.call_context import CallContext, _api_level_md
@@ -1221,7 +1221,6 @@ def _execute_hybrid_search(
12211221
return SearchFuture(None, None, e)
12221222
raise e from e
12231223

1224-
@retry_on_rpc_failure()
12251224
@retry_on_rpc_failure()
12261225
def search(
12271226
self,
@@ -1590,7 +1589,7 @@ def describe_index(
15901589
info_dict["field_name"] = response.index_descriptions[0].field_name
15911590
info_dict["index_name"] = response.index_descriptions[0].index_name
15921591
if info_dict.get("params"):
1593-
info_dict["params"] = json.loads(info_dict["params"])
1592+
info_dict["params"] = orjson.loads(info_dict["params"])
15941593
info_dict["total_rows"] = response.index_descriptions[0].total_rows
15951594
info_dict["indexed_rows"] = response.index_descriptions[0].indexed_rows
15961595
info_dict["pending_index_rows"] = response.index_descriptions[0].pending_index_rows
@@ -2201,7 +2200,8 @@ def query(
22012200

22022201
keys = [field_data.field_name for field_data in response.fields_data]
22032202
filtered_keys = [k for k in keys if k != "$meta"]
2204-
results = [dict.fromkeys(filtered_keys) for _ in range(num_entities)]
2203+
template = dict.fromkeys(filtered_keys)
2204+
results = [template.copy() for _ in range(num_entities)]
22052205
lazy_field_data = []
22062206
for field_data in response.fields_data:
22072207
lazy_extracted = entity_helper.extract_row_data_from_fields_data_v2(field_data, results)

pymilvus/client/prepare.py

Lines changed: 34 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,16 @@
5757
)
5858
from .utils import get_params, traverse_info, traverse_upsert_info
5959

60+
_JSON_TYPE_MAP = {
61+
DataType.INT8: "Int8",
62+
DataType.INT16: "Int16",
63+
DataType.INT32: "Int32",
64+
DataType.INT64: "Int64",
65+
DataType.BOOL: "Bool",
66+
DataType.VARCHAR: "VarChar",
67+
DataType.STRING: "VarChar",
68+
}
69+
6070

6171
class Prepare:
6272
@classmethod
@@ -558,7 +568,7 @@ def _function_output_field_names(fields_info: List[Dict]):
558568

559569
@staticmethod
560570
def _num_input_fields(fields_info: List[Dict], is_upsert: bool):
561-
return len([field for field in fields_info if Prepare._is_input_field(field, is_upsert)])
571+
return sum(1 for field in fields_info if Prepare._is_input_field(field, is_upsert))
562572

563573
@staticmethod
564574
def _process_struct_field(
@@ -784,7 +794,7 @@ def _parse_row_request(
784794

785795
try:
786796
for entity in entities:
787-
if not isinstance(entity, Dict):
797+
if not isinstance(entity, dict):
788798
msg = f"expected Dict, got '{type(entity).__name__}'"
789799
raise TypeError(msg)
790800
for k, v in entity.items():
@@ -838,13 +848,12 @@ def _parse_row_request(
838848
raise DataNotMatchException(
839849
message=ExceptionsMessage.InsertMissedField % key
840850
)
841-
json_dict = {
842-
k: v
843-
for k, v in entity.items()
844-
if k not in fields_data and k not in struct_fields_data and enable_dynamic
845-
}
846-
847851
if enable_dynamic:
852+
json_dict = {
853+
k: v
854+
for k, v in entity.items()
855+
if k not in fields_data and k not in struct_fields_data
856+
}
848857
json_value = entity_helper.convert_to_json(json_dict)
849858
d_field.scalars.json_data.data.append(json_value)
850859

@@ -939,7 +948,7 @@ def _parse_upsert_row_request(
939948

940949
try:
941950
for entity in entities:
942-
if not isinstance(entity, Dict):
951+
if not isinstance(entity, dict):
943952
msg = f"expected Dict, got '{type(entity).__name__}'"
944953
raise TypeError(msg)
945954
for k, v in entity.items():
@@ -999,13 +1008,12 @@ def _parse_upsert_row_request(
9991008
raise DataNotMatchException(
10001009
message=ExceptionsMessage.InsertMissedField % key
10011010
)
1002-
json_dict = {
1003-
k: v
1004-
for k, v in entity.items()
1005-
if k not in fields_data and k not in struct_fields_data and enable_dynamic
1006-
}
1007-
10081011
if enable_dynamic:
1012+
json_dict = {
1013+
k: v
1014+
for k, v in entity.items()
1015+
if k not in fields_data and k not in struct_fields_data
1016+
}
10091017
json_value = entity_helper.convert_to_json(json_dict)
10101018
d_field.scalars.json_data.data.append(json_value)
10111019
field_len[DYNAMIC_FIELD_NAME] += 1
@@ -1049,7 +1057,7 @@ def _parse_upsert_row_request(
10491057
)
10501058
request.fields_data.extend(struct_fields_data.values())
10511059

1052-
for _, field in enumerate(input_fields_info):
1060+
for field in input_fields_info:
10531061
is_dynamic = False
10541062
field_name = field["name"]
10551063

@@ -1551,20 +1559,10 @@ def search_requests_with_expr(
15511559

15521560
json_type = kwargs.get(JSON_TYPE)
15531561
if json_type is not None:
1554-
if json_type == DataType.INT8:
1555-
search_params[JSON_TYPE] = "Int8"
1556-
elif json_type == DataType.INT16:
1557-
search_params[JSON_TYPE] = "Int16"
1558-
elif json_type == DataType.INT32:
1559-
search_params[JSON_TYPE] = "Int32"
1560-
elif json_type == DataType.INT64:
1561-
search_params[JSON_TYPE] = "Int64"
1562-
elif json_type == DataType.BOOL:
1563-
search_params[JSON_TYPE] = "Bool"
1564-
elif json_type in (DataType.VARCHAR, DataType.STRING):
1565-
search_params[JSON_TYPE] = "VarChar"
1566-
else:
1562+
json_type_name = _JSON_TYPE_MAP.get(json_type)
1563+
if json_type_name is None:
15671564
raise ParamError(message=f"Unsupported json cast type: {json_type}")
1565+
search_params[JSON_TYPE] = json_type_name
15681566

15691567
strict_cast = kwargs.get(STRICT_CAST)
15701568
if strict_cast is not None:
@@ -1708,41 +1706,15 @@ def hybrid_search_request_with_ranker(
17081706
]
17091707
)
17101708

1711-
if kwargs.get(RANK_GROUP_SCORER) is not None:
1712-
request.rank_params.extend(
1713-
[
1714-
common_types.KeyValuePair(
1715-
key=RANK_GROUP_SCORER, value=kwargs.get(RANK_GROUP_SCORER)
1716-
)
1717-
]
1718-
)
1719-
1720-
if kwargs.get(GROUP_BY_FIELD) is not None:
1721-
request.rank_params.extend(
1722-
[
1723-
common_types.KeyValuePair(
1724-
key=GROUP_BY_FIELD, value=utils.dumps(kwargs.get(GROUP_BY_FIELD))
1725-
)
1726-
]
1727-
)
1728-
1729-
if kwargs.get(GROUP_SIZE) is not None:
1730-
request.rank_params.extend(
1731-
[
1732-
common_types.KeyValuePair(
1733-
key=GROUP_SIZE, value=utils.dumps(kwargs.get(GROUP_SIZE))
1734-
)
1735-
]
1736-
)
1737-
1738-
if kwargs.get(STRICT_GROUP_SIZE) is not None:
1739-
request.rank_params.extend(
1740-
[
1709+
for param_key in (RANK_GROUP_SCORER, GROUP_BY_FIELD, GROUP_SIZE, STRICT_GROUP_SIZE):
1710+
val = kwargs.get(param_key)
1711+
if val is not None:
1712+
request.rank_params.append(
17411713
common_types.KeyValuePair(
1742-
key=STRICT_GROUP_SIZE, value=utils.dumps(kwargs.get(STRICT_GROUP_SIZE))
1714+
key=param_key,
1715+
value=val if param_key == RANK_GROUP_SCORER else utils.dumps(val),
17431716
)
1744-
]
1745-
)
1717+
)
17461718

17471719
if isinstance(rerank, Function):
17481720
request.function_score.CopyFrom(Prepare.ranker_to_function_score(rerank))

0 commit comments

Comments (0)