Skip to content

Commit 32c2e5f

Browse files
xiangyan99pvaneck
andauthored
Fix the answer query bug (#35259)
* Fix the answer query bug * Add tests * Update CHANGELOG.md * Update CHANGELOG.md * update * Update sdk/search/azure-search-documents/azure/search/documents/_utils.py Co-authored-by: Paul Van Eck <[email protected]> --------- Co-authored-by: Paul Van Eck <[email protected]>
1 parent 36819f0 commit 32c2e5f

File tree

6 files changed

+42
-9
lines changed

6 files changed

+42
-9
lines changed

sdk/search/azure-search-documents/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
### Bugs Fixed
1010

11+
- Fixed the bug that SearchClient failed when both answer count and answer threshold applied.
12+
1113
### Other Changes
1214

1315
## 11.6.0b3 (2024-04-09)

sdk/search/azure-search-documents/assets.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
"AssetsRepo": "Azure/azure-sdk-assets",
33
"AssetsRepoPrefixPath": "python",
44
"TagPrefix": "python/search/azure-search-documents",
5-
"Tag": "python/search/azure-search-documents_a3db22f661"
5+
"Tag": "python/search/azure-search-documents_dd0ce4c420"
66
}

sdk/search/azure-search-documents/azure/search/documents/_search_client.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from ._paging import SearchItemPaged, SearchPageIterator
3535
from ._queries import AutocompleteQuery, SearchQuery, SuggestQuery
3636
from ._headers_mixin import HeadersMixin
37-
from ._utils import get_authentication_policy
37+
from ._utils import get_authentication_policy, get_answer_query
3838
from ._version import SDK_MONIKER
3939

4040

@@ -326,8 +326,7 @@ def search(
326326
filter_arg = filter
327327
search_fields_str = ",".join(search_fields) if search_fields else None
328328

329-
answers = query_answer if not query_answer_count else "{}|count-{}".format(query_answer, query_answer_count)
330-
answers = answers if not query_answer_threshold else "{}|threshold-{}".format(answers, query_answer_threshold)
329+
answers = get_answer_query(query_answer, query_answer_count, query_answer_threshold)
331330

332331
captions = (
333332
query_caption

sdk/search/azure-search-documents/azure/search/documents/_utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,31 @@
33
# Licensed under the MIT License. See License.txt in the project root for
44
# license information.
55
# --------------------------------------------------------------------------
6-
from typing import Any, Optional
6+
from typing import Any, Optional, Union
77
from azure.core.pipeline.policies import (
88
BearerTokenCredentialPolicy,
99
AsyncBearerTokenCredentialPolicy,
1010
)
11+
from ._generated.models import QueryAnswerType
1112

1213
DEFAULT_AUDIENCE = "https://search.azure.com"
1314

1415

16+
def get_answer_query(
17+
query_answer: Optional[Union[str, QueryAnswerType]] = None,
18+
query_answer_count: Optional[int] = None,
19+
query_answer_threshold: Optional[float] = None,
20+
) -> Optional[Union[str, QueryAnswerType]]:
21+
answers = query_answer
22+
separator = "|"
23+
if query_answer_count:
24+
answers = f"{answers}{separator}count-{query_answer_count}"
25+
separator = ","
26+
if query_answer_threshold:
27+
answers = f"{answers}{separator}threshold-{query_answer_threshold}"
28+
return answers
29+
30+
1531
def is_retryable_status_code(status_code: Optional[int]) -> bool:
1632
if not status_code:
1733
return False

sdk/search/azure-search-documents/azure/search/documents/aio/_search_client_async.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from azure.core.credentials_async import AsyncTokenCredential
1111
from azure.core.tracing.decorator_async import distributed_trace_async
1212
from ._paging import AsyncSearchItemPaged, AsyncSearchPageIterator
13-
from .._utils import get_authentication_policy
13+
from .._utils import get_authentication_policy, get_answer_query
1414
from .._generated.aio import SearchIndexClient
1515
from .._generated.models import (
1616
AutocompleteMode,
@@ -329,8 +329,7 @@ async def search(
329329
include_total_result_count = include_total_count
330330
filter_arg = filter
331331
search_fields_str = ",".join(search_fields) if search_fields else None
332-
answers = query_answer if not query_answer_count else "{}|count-{}".format(query_answer, query_answer_count)
333-
answers = answers if not query_answer_threshold else "{}|threshold-{}".format(answers, query_answer_threshold)
332+
answers = get_answer_query(query_answer, query_answer_count, query_answer_threshold)
334333
captions = (
335334
query_caption
336335
if not query_caption_highlight_enabled

sdk/search/azure-search-documents/tests/test_search_client.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
RequestEntityTooLargeError,
2121
ApiVersion,
2222
)
23-
from azure.search.documents._utils import odata
23+
from azure.search.documents._utils import odata, get_answer_query
2424

2525
CREDENTIAL = AzureKeyCredential(key="test_api_key")
2626

@@ -56,6 +56,23 @@ def test_prevent_double_quoting(self):
5656
assert odata("foo eq '{foo}'", foo="a string") == "foo eq 'a string'"
5757

5858

59+
class TestAnswerQuery:
60+
def test_no_args(self):
61+
assert get_answer_query() is None
62+
63+
def test_query_answer(self):
64+
assert get_answer_query("query") == "query"
65+
66+
def test_query_answer_count(self):
67+
assert get_answer_query("query", 5) == "query|count-5"
68+
69+
def test_query_answer_threshold(self):
70+
assert get_answer_query("query", query_answer_threshold=0.5) == "query|threshold-0.5"
71+
72+
def test_query_answer_count_threshold(self):
73+
assert get_answer_query("query", 5, 0.5) == "query|count-5,threshold-0.5"
74+
75+
5976
class TestSearchClient:
6077
def test_init(self):
6178
client = SearchClient("endpoint", "index name", CREDENTIAL)

0 commit comments

Comments
 (0)