From cd3beda7402fcf77fe251385d394cb9088361b9a Mon Sep 17 00:00:00 2001 From: kukushking Date: Thu, 2 Jan 2025 17:33:10 +0000 Subject: [PATCH 1/2] return opensearch aggregation hits --- awswrangler/opensearch/_read.py | 34 ++++++++++++++++++++++++++------- tests/unit/test_opensearch.py | 26 +++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/awswrangler/opensearch/_read.py b/awswrangler/opensearch/_read.py index d1ead8418..283f1f9d7 100644 --- a/awswrangler/opensearch/_read.py +++ b/awswrangler/opensearch/_read.py @@ -41,12 +41,24 @@ def _hit_to_row(hit: Mapping[str, Any]) -> Mapping[str, Any]: return row -def _search_response_to_documents(response: Mapping[str, Any]) -> list[Mapping[str, Any]]: - return [_hit_to_row(hit) for hit in response.get("hits", {}).get("hits", [])] - - -def _search_response_to_df(response: Mapping[str, Any] | Any) -> pd.DataFrame: - return pd.DataFrame(_search_response_to_documents(response)) +def _search_response_to_documents( + response: Mapping[str, Any], aggregations: list[str] | None = None +) -> list[Mapping[str, Any]]: + hits = response.get("hits", {}).get("hits", []) + if not hits and aggregations: + hits = [ + aggregation_hit + for aggregation_name in aggregations + for aggregation_hit in response.get("aggregations", {}) + .get(aggregation_name, {}) + .get("hits", {}) + .get("hits", []) + ] + return [_hit_to_row(hit) for hit in hits] + + +def _search_response_to_df(response: Mapping[str, Any] | Any, aggregations: list[str] | None = None) -> pd.DataFrame: + return pd.DataFrame(_search_response_to_documents(response=response, aggregations=aggregations)) @_utils.check_optional_dependency(opensearchpy, "opensearchpy") @@ -128,8 +140,16 @@ def search( documents = [_hit_to_row(doc) for doc in documents_generator] df = pd.DataFrame(documents) else: + aggregations = ( + list(search_body.get("aggregations", {}).keys() or search_body.get("aggs", {}).keys()) + if search_body + else None + ) response = client.search(index=index, body=search_body, filter_path=filter_path, **kwargs) - df = _search_response_to_df(response) + df = _search_response_to_df( + response=response, + aggregations=aggregations, + ) return df diff --git a/tests/unit/test_opensearch.py b/tests/unit/test_opensearch.py index 3fbc8293a..7dada866b 100644 --- a/tests/unit/test_opensearch.py +++ b/tests/unit/test_opensearch.py @@ -424,6 +424,32 @@ def test_search_scroll(client): wr.opensearch.delete_index(client, index) +def test_search_aggregation(client): + index = f"test_search_agg_{_get_unique_suffix()}" + kwargs = {} if _is_serverless(client) else {"refresh": "wait_for"} + try: + wr.opensearch.index_documents( + client, + documents=inspections_documents, + index=index, + id_keys=["inspection_id"], + **kwargs, + ) + if _is_serverless(client): + # The refresh interval for OpenSearch Serverless is between 10 and 30 seconds + # depending on the size of the request. + time.sleep(30) + df = wr.opensearch.search( + client, + index=index, + search_body={"aggregations": {"top_hits_inspections": {"top_hits": {"size": 2}}}}, + filter_path=["aggregations"], + ) + assert df.shape[0] == 2 + finally: + wr.opensearch.delete_index(client, index) + + @pytest.mark.parametrize("fetch_size", [None, 1000, 10000]) @pytest.mark.parametrize("fetch_size_param_name", ["size", "fetch_size"]) def test_search_sql(client, fetch_size, fetch_size_param_name): From 9c566589c040208245dedd6c84d93b88d6bc95a5 Mon Sep 17 00:00:00 2001 From: kukushking Date: Fri, 3 Jan 2025 17:21:15 +0000 Subject: [PATCH 2/2] add _aggregation_name --- awswrangler/opensearch/_read.py | 2 +- tests/unit/test_opensearch.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/awswrangler/opensearch/_read.py b/awswrangler/opensearch/_read.py index 283f1f9d7..b4356411c 100644 --- a/awswrangler/opensearch/_read.py +++ b/awswrangler/opensearch/_read.py @@ -47,7 +47,7 @@ def _search_response_to_documents( hits = response.get("hits", {}).get("hits", []) if not hits and aggregations: hits = [ - aggregation_hit + dict(aggregation_hit, _aggregation_name=aggregation_name) for aggregation_name in aggregations for aggregation_hit in response.get("aggregations", {}) .get(aggregation_name, {}) diff --git a/tests/unit/test_opensearch.py b/tests/unit/test_opensearch.py index 7dada866b..422ccc8c6 100644 --- a/tests/unit/test_opensearch.py +++ b/tests/unit/test_opensearch.py @@ -442,10 +442,19 @@ def test_search_aggregation(client): df = wr.opensearch.search( client, index=index, - search_body={"aggregations": {"top_hits_inspections": {"top_hits": {"size": 2}}}}, + search_body={ + "aggregations": { + "latest_inspections": {"top_hits": {"sort": [{"inspection_date": {"order": "asc"}}], "size": 1}}, + "lowest_inspection_score": { + "top_hits": {"sort": [{"inspection_score": {"order": "asc"}}], "size": 1} + }, + } + }, filter_path=["aggregations"], ) assert df.shape[0] == 2 + assert len(df.loc[df["_aggregation_name"] == "latest_inspections"]) == 1 + assert len(df.loc[df["_aggregation_name"] == "lowest_inspection_score"]) == 1 finally: wr.opensearch.delete_index(client, index)