|
| 1 | +"""Amazon OpenSearch Read Module (PRIVATE).""" |
| 2 | + |
| 3 | +from typing import Any, Collection, Dict, List, Mapping, Optional, Union |
| 4 | + |
| 5 | +import pandas as pd |
| 6 | +from opensearchpy import OpenSearch |
| 7 | +from opensearchpy.helpers import scan |
| 8 | + |
| 9 | +from awswrangler.opensearch._utils import _get_distribution |
| 10 | + |
| 11 | + |
| 12 | +def _resolve_fields(row: Mapping[str, Any]) -> Mapping[str, Any]: |
| 13 | + fields = {} |
| 14 | + for field in row: |
| 15 | + if isinstance(row[field], dict): |
| 16 | + nested_fields = _resolve_fields(row[field]) |
| 17 | + for n_field, val in nested_fields.items(): |
| 18 | + fields[f"{field}.{n_field}"] = val |
| 19 | + else: |
| 20 | + fields[field] = row[field] |
| 21 | + return fields |
| 22 | + |
| 23 | + |
| 24 | +def _hit_to_row(hit: Mapping[str, Any]) -> Mapping[str, Any]: |
| 25 | + row: Dict[str, Any] = {} |
| 26 | + for k in hit.keys(): |
| 27 | + if k == "_source": |
| 28 | + solved_fields = _resolve_fields(hit["_source"]) |
| 29 | + row.update(solved_fields) |
| 30 | + elif k.startswith("_"): |
| 31 | + row[k] = hit[k] |
| 32 | + return row |
| 33 | + |
| 34 | + |
| 35 | +def _search_response_to_documents(response: Mapping[str, Any]) -> List[Mapping[str, Any]]: |
| 36 | + return [_hit_to_row(hit) for hit in response["hits"]["hits"]] |
| 37 | + |
| 38 | + |
def _search_response_to_df(response: Union[Mapping[str, Any], Any]) -> pd.DataFrame:
    """Build a pandas DataFrame from a search response, one row per hit."""
    rows = _search_response_to_documents(response)
    return pd.DataFrame(rows)
| 41 | + |
| 42 | + |
def search(
    client: OpenSearch,
    index: Optional[str] = "_all",
    search_body: Optional[Dict[str, Any]] = None,
    doc_type: Optional[str] = None,
    is_scroll: Optional[bool] = False,
    filter_path: Optional[Union[str, Collection[str]]] = None,
    **kwargs: Any,
) -> pd.DataFrame:
    """Return results matching query DSL as pandas dataframe.

    Parameters
    ----------
    client : OpenSearch
        instance of opensearchpy.OpenSearch to use.
    index : str, optional
        A comma-separated list of index names to search.
        Use `_all` or empty string to perform the operation on all indices.
    search_body : Dict[str, Any], optional
        The search definition using the
        [Query DSL](https://opensearch.org/docs/opensearch/query-dsl/full-text/).
    doc_type : str, optional
        Name of the document type (for Elasticsearch versions 5.x and earlier).
        Only forwarded to the underlying call when it is set.
    is_scroll : bool, optional
        Retrieve a large number of results from a single search request using
        [scroll](https://opensearch.org/docs/opensearch/rest-api/scroll/),
        for example for machine learning jobs.
        Because scroll search contexts consume a lot of memory, we suggest you
        don't use the scroll operation for frequent user queries.
    filter_path : Union[str, Collection[str]], optional
        Use the filter_path parameter to reduce the size of the OpenSearch
        Service response (default: ['hits.hits._id','hits.hits._source']).
        When `is_scroll=True`, `_scroll_id` and `_shards` are added in front
        of the given paths.
    **kwargs :
        KEYWORD arguments forwarded to
        [opensearchpy.OpenSearch.search](https://opensearch-py.readthedocs.io/en/latest/api.html#opensearchpy.OpenSearch.search),
        or to
        [opensearchpy.helpers.scan](https://opensearch-py.readthedocs.io/en/master/helpers.html#scan)
        if `is_scroll=True`.

    Returns
    -------
    pandas.DataFrame
        Results as Pandas DataFrame

    Examples
    --------
    Searching an index using query DSL

    >>> import awswrangler as wr
    >>> client = wr.opensearch.connect(host='DOMAIN-ENDPOINT')
    >>> df = wr.opensearch.search(
    ...         client=client,
    ...         index='movies',
    ...         search_body={
    ...           "query": {
    ...             "match": {
    ...               "title": "wind"
    ...             }
    ...           }
    ...         }
    ...      )

    """
    if doc_type:
        kwargs["doc_type"] = doc_type

    if filter_path is None:
        filter_path = ["hits.hits._id", "hits.hits._source"]

    if not is_scroll:
        # Plain search: a single request whose response holds all returned hits.
        response = client.search(index=index, body=search_body, filter_path=filter_path, **kwargs)
        return _search_response_to_df(response)

    # Scroll search: normalise filter_path to a list and keep the fields that
    # the scroll helper needs in each response.
    requested_paths = [filter_path] if isinstance(filter_path, str) else list(filter_path)
    scroll_paths = ["_scroll_id", "_shards"] + requested_paths  # required for scroll
    hits = scan(client, index=index, query=search_body, filter_path=scroll_paths, **kwargs)
    rows = [_hit_to_row(hit) for hit in hits]
    return pd.DataFrame(rows)
| 122 | + |
| 123 | + |
def search_by_sql(client: OpenSearch, sql_query: str, **kwargs: Any) -> pd.DataFrame:
    """Return results matching [SQL query](https://opensearch.org/docs/search-plugins/sql/index/) as pandas dataframe.

    Parameters
    ----------
    client : OpenSearch
        instance of opensearchpy.OpenSearch to use.
    sql_query : str
        SQL query
    **kwargs :
        KEYWORD arguments forwarded to request url (e.g.: filter_path, etc.).
        `size` / `fetch_size` are sent in the request body as `fetch_size`
        instead, and `format` is always set to `json`.

    Returns
    -------
    pandas.DataFrame
        Results as Pandas DataFrame

    Examples
    --------
    Searching an index using SQL query

    >>> import awswrangler as wr
    >>> client = wr.opensearch.connect(host='DOMAIN-ENDPOINT')
    >>> df = wr.opensearch.search_by_sql(
    ...         client=client,
    ...         sql_query='SELECT * FROM my-index LIMIT 50'
    ...      )

    """
    # The SQL endpoint path depends on the distribution reported for the client.
    endpoint = "/_plugins/_sql" if _get_distribution(client) == "opensearch" else "/_opendistro/_sql"

    kwargs["format"] = "json"
    payload = {"query": sql_query}
    for size_key in ("size", "fetch_size"):
        if size_key in kwargs:
            # Moved from the URL parameters into the body: unrecognized parameter
            payload["fetch_size"] = kwargs.pop(size_key)

    response = client.transport.perform_request(
        "POST",
        endpoint,
        headers={"Content-Type": "application/json"},
        body=payload,
        params=kwargs,
    )
    return _search_response_to_df(response)
0 commit comments