Skip to content

Commit 74fce30

Browse files
XuanYang-cn and claude
authored
enhance: general performance improvements across MilvusClient path (#3280)
- **Fix duplicate `@retry_on_rpc_failure`** on `search()` in `grpc_handler.py` — was causing up to 75×75=5625 retry attempts instead of 75 - **24 micro-optimizations** across the MilvusClient/AsyncMilvusClient code path (no refactoring, no API changes): - Replace O(n²) bytes concatenation with `b"".join` and `struct.pack` batch (~136x faster at n=1000) - Use `orjson.loads` instead of stdlib `json.loads` (~6.7x faster) - Replace `deepcopy` with shallow `dict()` for search params (~4.2x faster) - Use `frozenset` for O(1) membership checks instead of list O(n) (~3.6x faster) - Use `WhichOneof` instead of cascading `HasField` in `len_of` - Use `time.monotonic` + lazy `traceback.format_exc` in error_handler (~13-17x faster) - Use builtin `isinstance(x, dict/list)` instead of `typing.Dict/List` (~3.3-6.5x faster) - Cache protobuf attribute chains to avoid repeated traversals (~1.3x faster) - Use dict dispatch tables instead of if/elif chains - Move per-call dict/function creation to module-level constants (~4.5x faster) - Use `template.copy()` instead of `dict.fromkeys` per iteration (~5.4x faster) - Skip dynamic field json comprehension when dynamic is disabled (~16x faster) - Consolidate `extend([item])` to `append` (~2.8x faster) - **80 benchmark tests** added in `tests/benchmark/test_perf_improvements.py` covering all optimization patterns - Net **-536 / +1562 lines** across 11 files (10 production + 1 benchmark test file) - All 2126 existing unit tests pass, all 80 benchmark tests pass - [x] `make lint` passes (black + ruff clean) - [x] `make unittest` — 2126 passed, 4 skipped, 1 xfailed - [x] `pytest tests/benchmark/test_perf_improvements.py` — 80 benchmarks pass (cherry picked from commit f69e354) --------- Signed-off-by: yangxuan <xuan.yang@zilliz.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 91e5222 commit 74fce30

File tree

11 files changed

+1557
-522
lines changed

11 files changed

+1557
-522
lines changed

pymilvus/client/async_grpc_handler.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import asyncio
22
import base64
3-
import json
43
import socket
54
import time
65
from pathlib import Path
76
from typing import Dict, List, Optional, Tuple, Union
87
from urllib import parse
98

109
import grpc
10+
import orjson
1111
from grpc._cython import cygrpc
1212

1313
from pymilvus.client.call_context import CallContext, _api_level_md
@@ -1410,7 +1410,8 @@ async def query(
14101410
_, dynamic_fields = entity_helper.extract_dynamic_field_from_result(response)
14111411
keys = [field_data.field_name for field_data in response.fields_data]
14121412
filtered_keys = [k for k in keys if k != "$meta"]
1413-
results = [dict.fromkeys(filtered_keys) for _ in range(num_entities)]
1413+
template = dict.fromkeys(filtered_keys)
1414+
results = [template.copy() for _ in range(num_entities)]
14141415
lazy_field_data = []
14151416
for field_data in response.fields_data:
14161417
lazy_extracted = entity_helper.extract_row_data_from_fields_data_v2(field_data, results)
@@ -1616,7 +1617,7 @@ async def describe_index(
16161617
info_dict["field_name"] = response.index_descriptions[0].field_name
16171618
info_dict["index_name"] = response.index_descriptions[0].index_name
16181619
if info_dict.get("params"):
1619-
info_dict["params"] = json.loads(info_dict["params"])
1620+
info_dict["params"] = orjson.loads(info_dict["params"])
16201621
info_dict["total_rows"] = response.index_descriptions[0].total_rows
16211622
info_dict["indexed_rows"] = response.index_descriptions[0].indexed_rows
16221623
info_dict["pending_index_rows"] = response.index_descriptions[0].pending_index_rows

pymilvus/client/entity_helper.py

Lines changed: 163 additions & 223 deletions
Large diffs are not rendered by default.

pymilvus/client/grpc_handler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import base64
2-
import json
32
import logging
43
import socket
54
import threading
@@ -9,6 +8,7 @@
98
from urllib import parse
109

1110
import grpc
11+
import orjson
1212
from grpc._cython import cygrpc
1313

1414
from pymilvus.client.call_context import CallContext, _api_level_md
@@ -1162,7 +1162,6 @@ def _execute_hybrid_search(
11621162
return SearchFuture(None, None, e)
11631163
raise e from e
11641164

1165-
@retry_on_rpc_failure()
11661165
@retry_on_rpc_failure()
11671166
def search(
11681167
self,
@@ -1531,7 +1530,7 @@ def describe_index(
15311530
info_dict["field_name"] = response.index_descriptions[0].field_name
15321531
info_dict["index_name"] = response.index_descriptions[0].index_name
15331532
if info_dict.get("params"):
1534-
info_dict["params"] = json.loads(info_dict["params"])
1533+
info_dict["params"] = orjson.loads(info_dict["params"])
15351534
info_dict["total_rows"] = response.index_descriptions[0].total_rows
15361535
info_dict["indexed_rows"] = response.index_descriptions[0].indexed_rows
15371536
info_dict["pending_index_rows"] = response.index_descriptions[0].pending_index_rows
@@ -2142,7 +2141,8 @@ def query(
21422141

21432142
keys = [field_data.field_name for field_data in response.fields_data]
21442143
filtered_keys = [k for k in keys if k != "$meta"]
2145-
results = [dict.fromkeys(filtered_keys) for _ in range(num_entities)]
2144+
template = dict.fromkeys(filtered_keys)
2145+
results = [template.copy() for _ in range(num_entities)]
21462146
lazy_field_data = []
21472147
for field_data in response.fields_data:
21482148
lazy_extracted = entity_helper.extract_row_data_from_fields_data_v2(field_data, results)

pymilvus/client/prepare.py

Lines changed: 34 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,16 @@
5656
)
5757
from .utils import get_params, traverse_info, traverse_upsert_info
5858

59+
_JSON_TYPE_MAP = {
60+
DataType.INT8: "Int8",
61+
DataType.INT16: "Int16",
62+
DataType.INT32: "Int32",
63+
DataType.INT64: "Int64",
64+
DataType.BOOL: "Bool",
65+
DataType.VARCHAR: "VarChar",
66+
DataType.STRING: "VarChar",
67+
}
68+
5969

6070
class Prepare:
6171
@classmethod
@@ -557,7 +567,7 @@ def _function_output_field_names(fields_info: List[Dict]):
557567

558568
@staticmethod
559569
def _num_input_fields(fields_info: List[Dict], is_upsert: bool):
560-
return len([field for field in fields_info if Prepare._is_input_field(field, is_upsert)])
570+
return sum(1 for field in fields_info if Prepare._is_input_field(field, is_upsert))
561571

562572
@staticmethod
563573
def _process_struct_field(
@@ -783,7 +793,7 @@ def _parse_row_request(
783793

784794
try:
785795
for entity in entities:
786-
if not isinstance(entity, Dict):
796+
if not isinstance(entity, dict):
787797
msg = f"expected Dict, got '{type(entity).__name__}'"
788798
raise TypeError(msg)
789799
for k, v in entity.items():
@@ -837,13 +847,12 @@ def _parse_row_request(
837847
raise DataNotMatchException(
838848
message=ExceptionsMessage.InsertMissedField % key
839849
)
840-
json_dict = {
841-
k: v
842-
for k, v in entity.items()
843-
if k not in fields_data and k not in struct_fields_data and enable_dynamic
844-
}
845-
846850
if enable_dynamic:
851+
json_dict = {
852+
k: v
853+
for k, v in entity.items()
854+
if k not in fields_data and k not in struct_fields_data
855+
}
847856
json_value = entity_helper.convert_to_json(json_dict)
848857
d_field.scalars.json_data.data.append(json_value)
849858

@@ -938,7 +947,7 @@ def _parse_upsert_row_request(
938947

939948
try:
940949
for entity in entities:
941-
if not isinstance(entity, Dict):
950+
if not isinstance(entity, dict):
942951
msg = f"expected Dict, got '{type(entity).__name__}'"
943952
raise TypeError(msg)
944953
for k, v in entity.items():
@@ -998,13 +1007,12 @@ def _parse_upsert_row_request(
9981007
raise DataNotMatchException(
9991008
message=ExceptionsMessage.InsertMissedField % key
10001009
)
1001-
json_dict = {
1002-
k: v
1003-
for k, v in entity.items()
1004-
if k not in fields_data and k not in struct_fields_data and enable_dynamic
1005-
}
1006-
10071010
if enable_dynamic:
1011+
json_dict = {
1012+
k: v
1013+
for k, v in entity.items()
1014+
if k not in fields_data and k not in struct_fields_data
1015+
}
10081016
json_value = entity_helper.convert_to_json(json_dict)
10091017
d_field.scalars.json_data.data.append(json_value)
10101018
field_len[DYNAMIC_FIELD_NAME] += 1
@@ -1048,7 +1056,7 @@ def _parse_upsert_row_request(
10481056
)
10491057
request.fields_data.extend(struct_fields_data.values())
10501058

1051-
for _, field in enumerate(input_fields_info):
1059+
for field in input_fields_info:
10521060
is_dynamic = False
10531061
field_name = field["name"]
10541062

@@ -1546,20 +1554,10 @@ def search_requests_with_expr(
15461554

15471555
json_type = kwargs.get(JSON_TYPE)
15481556
if json_type is not None:
1549-
if json_type == DataType.INT8:
1550-
search_params[JSON_TYPE] = "Int8"
1551-
elif json_type == DataType.INT16:
1552-
search_params[JSON_TYPE] = "Int16"
1553-
elif json_type == DataType.INT32:
1554-
search_params[JSON_TYPE] = "Int32"
1555-
elif json_type == DataType.INT64:
1556-
search_params[JSON_TYPE] = "Int64"
1557-
elif json_type == DataType.BOOL:
1558-
search_params[JSON_TYPE] = "Bool"
1559-
elif json_type in (DataType.VARCHAR, DataType.STRING):
1560-
search_params[JSON_TYPE] = "VarChar"
1561-
else:
1557+
json_type_name = _JSON_TYPE_MAP.get(json_type)
1558+
if json_type_name is None:
15621559
raise ParamError(message=f"Unsupported json cast type: {json_type}")
1560+
search_params[JSON_TYPE] = json_type_name
15631561

15641562
strict_cast = kwargs.get(STRICT_CAST)
15651563
if strict_cast is not None:
@@ -1699,41 +1697,15 @@ def hybrid_search_request_with_ranker(
16991697
]
17001698
)
17011699

1702-
if kwargs.get(RANK_GROUP_SCORER) is not None:
1703-
request.rank_params.extend(
1704-
[
1705-
common_types.KeyValuePair(
1706-
key=RANK_GROUP_SCORER, value=kwargs.get(RANK_GROUP_SCORER)
1707-
)
1708-
]
1709-
)
1710-
1711-
if kwargs.get(GROUP_BY_FIELD) is not None:
1712-
request.rank_params.extend(
1713-
[
1714-
common_types.KeyValuePair(
1715-
key=GROUP_BY_FIELD, value=utils.dumps(kwargs.get(GROUP_BY_FIELD))
1716-
)
1717-
]
1718-
)
1719-
1720-
if kwargs.get(GROUP_SIZE) is not None:
1721-
request.rank_params.extend(
1722-
[
1723-
common_types.KeyValuePair(
1724-
key=GROUP_SIZE, value=utils.dumps(kwargs.get(GROUP_SIZE))
1725-
)
1726-
]
1727-
)
1728-
1729-
if kwargs.get(STRICT_GROUP_SIZE) is not None:
1730-
request.rank_params.extend(
1731-
[
1700+
for param_key in (RANK_GROUP_SCORER, GROUP_BY_FIELD, GROUP_SIZE, STRICT_GROUP_SIZE):
1701+
val = kwargs.get(param_key)
1702+
if val is not None:
1703+
request.rank_params.append(
17321704
common_types.KeyValuePair(
1733-
key=STRICT_GROUP_SIZE, value=utils.dumps(kwargs.get(STRICT_GROUP_SIZE))
1705+
key=param_key,
1706+
value=val if param_key == RANK_GROUP_SCORER else utils.dumps(val),
17341707
)
1735-
]
1736-
)
1708+
)
17371709

17381710
if isinstance(rerank, Function):
17391711
request.function_score.CopyFrom(Prepare.ranker_to_function_score(rerank))

0 commit comments

Comments (0)