From 698c1b26a460bd20a224d040892e3eebfb9a6bae Mon Sep 17 00:00:00 2001
From: Petro Kostiukevych
Date: Wed, 15 Jul 2020 12:44:05 +0300
Subject: [PATCH 1/3] implement composite aggregation scan

---
 aioelasticsearch/helpers.py  | 150 ++++++++++++++++++-
 tests/test_composite_scan.py | 275 +++++++++++++++++++++++++++++++++++
 2 files changed, 423 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_composite_scan.py

diff --git a/aioelasticsearch/helpers.py b/aioelasticsearch/helpers.py
index a4e19e5e..12b86cf9 100644
--- a/aioelasticsearch/helpers.py
+++ b/aioelasticsearch/helpers.py
@@ -1,10 +1,12 @@
+import asyncio
 import logging
+from copy import deepcopy
 
 from elasticsearch.helpers import ScanError
 
-from aioelasticsearch import NotFoundError
+from aioelasticsearch import ElasticsearchException, NotFoundError
 
-__all__ = ('Scan', 'ScanError')
+__all__ = ('CompositeAggregationScan', 'Scan', 'ScanError')
 
 logger = logging.getLogger('elasticsearch')
 
@@ -140,3 +142,147 @@ def _update_state(self, resp):
         self._successful_shards = resp['_shards']['successful']
         self._total_shards = resp['_shards']['total']
         self._done = not self._hits or self._scroll_id is None
+
+
+class CompositeAggregationScan:
+
+    def __init__(
+        self,
+        es,
+        query,
+        raise_on_error=True,
+        prefetch_next_chunk=False,
+        **kwargs,
+    ):
+        self._es = es
+        self._query = deepcopy(query)
+        self._raise_on_error = raise_on_error
+        self._prefetch_next_chunk = prefetch_next_chunk
+        self._kwargs = kwargs
+
+        self._aggs_key = self._extract_aggs_key()
+
+        if 'composite' not in self._query['aggs'][self._aggs_key]:
+            raise RuntimeError(
+                'Scroll is available only for composite aggregations.',
+            )
+
+        self._after_key = None
+
+        self._initial = True
+        self._done = False
+        self._buckets = []
+        self._buckets_idx = 0
+
+        self._successful_shards = 0
+        self._total_shards = 0
+        self._prefetched = None
+
+    def _extract_aggs_key(self):
+        try:
+            return list(self._query['aggs'].keys())[0]
+        except (KeyError, IndexError):
+            raise RuntimeError(
+                "Can't get aggregation key from query {query}."
+                .format(query=self._query),
+            )
+
+    async def __aenter__(self):  # noqa
+        self._initial = False
+        await self._fetch_results()
+
+        return self
+
+    async def __aexit__(self, *exc_info):  # noqa
+        self._reset_prefetched()
+
+    def __aiter__(self):
+        if self._initial:
+            raise RuntimeError(
+                'Scan operations should be done '
+                'inside an async context manager.',
+            )
+
+        return self
+
+    async def __anext__(self):
+        if self._done:
+            raise StopAsyncIteration
+
+        if self._buckets_idx >= len(self._buckets):
+            if self._successful_shards < self._total_shards:
+                logger.warning(
+                    'Aggregation request has only succeeded '
+                    'on %d shards out of %d.',
+                    self._successful_shards, self._total_shards,
+                )
+                if self._raise_on_error:
+                    raise ElasticsearchException(
+                        'Aggregation request has only succeeded '
+                        'on {} shards out of {}.'
+                        .format(self._successful_shards, self._total_shards),
+                    )
+
+            await self._fetch_results()
+            if self._done:
+                raise StopAsyncIteration
+
+        ret = self._buckets[self._buckets_idx]
+        self._buckets_idx += 1
+
+        return ret
+
+    async def _search(self):
+        found, resp = True, None
+        try:
+            resp = await self._es.search(
+                body=self._query,
+                **self._kwargs,
+            )
+        except NotFoundError:
+            found = False
+
+        return found, resp
+
+    def _reset_prefetched(self):
+        if self._prefetched is not None and not self._prefetched.cancelled():  # noqa
+            self._prefetched.cancel()
+
+        self._prefetched = None
+
+    async def _fetch_results(self):
+        if self._prefetched is not None:
+            found, resp = await self._prefetched
+            self._reset_prefetched()
+        else:
+            found, resp = await self._search()
+
+        if not found:
+            self._done = True
+
+            return
+
+        self._update_state(resp)
+
+        if self._prefetch_next_chunk:
+            self._prefetched = asyncio.create_task(
+                self._search(),
+            )
+
+    def _update_query(self):
+        if self._after_key is None:
+            return
+
+        self._query['aggs'][self._aggs_key]['composite']['after'] = self._after_key  # noqa
+
+    def _update_state(self, resp):
+        self._after_key = resp['aggregations'][self._aggs_key].get('after_key')
+        self._buckets = resp['aggregations'][self._aggs_key]['buckets']
+        self._buckets_idx = 0
+
+        self._update_query()
+
+        self._successful_shards = resp['_shards']['successful']
+        self._total_shards = resp['_shards']['total']
+
+        self._done = not self._buckets or self._after_key is None
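For reference, the pagination protocol the class relies on: ``_update_state`` expects each ``search`` response to carry the composite aggregation under the query's single aggregation key, with ``buckets`` to yield, an optional ``after_key`` cursor to feed back via ``_update_query``, and shard counters. A minimal sketch of that shape; the values are illustrative, not part of this patch:

.. code-block:: python

    resp = {
        '_shards': {'successful': 5, 'total': 5},
        'aggregations': {
            'buckets': {  # aggregation key taken from the query ('aggs' -> first key)
                'after_key': {'score': '2'},  # cursor for the next page
                'buckets': [
                    {'key': {'score': '1'}, 'doc_count': 1},
                    {'key': {'score': '2'}, 'doc_count': 2},
                ],
            },
        },
    }

When ``after_key`` is absent or ``buckets`` is empty, ``_update_state`` marks the scan done and iteration stops.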
diff --git a/tests/test_composite_scan.py b/tests/test_composite_scan.py
new file mode 100644
index 00000000..bd7889f5
--- /dev/null
+++ b/tests/test_composite_scan.py
@@ -0,0 +1,275 @@
+import asyncio
+import logging
+from copy import deepcopy
+from unittest import mock
+
+import pytest
+
+from aioelasticsearch import ElasticsearchException
+from aioelasticsearch.helpers import CompositeAggregationScan
+
+logger = logging.getLogger('elasticsearch')
+
+
+ES_DATA = [
+    {
+        'score': '1',
+    },
+    {
+        'score': '2',
+    },
+    {
+        'score': '2',
+    },
+    {
+        'score': '3',
+    },
+    {
+        'score': '3',
+    },
+    {
+        'score': '3',
+    },
+]
+
+
+QUERY = {
+    'aggs': {
+        'buckets': {
+            'composite': {
+                'sources': [
+                    {'score': {'terms': {'field': 'score.keyword'}}},
+                ],
+            },
+        },
+    },
+}
+
+INDEX = 'test_aioes'
+
+
+@pytest.fixture
+def populate_aggs_data(loop, es):
+
+    async def do(index, docs):
+        coros = []
+
+        await es.indices.create(index)
+
+        for i, doc in enumerate(docs):
+            coros.append(
+                es.index(
+                    index=index,
+                    id=str(i),
+                    body=doc,
+                ),
+            )
+
+        await asyncio.gather(*coros, loop=loop)
+        await es.indices.refresh()
+
+    return do
+
+
+@pytest.mark.run_loop
+async def test_async_for_without_context_manager(es):
+    scan = CompositeAggregationScan(es, QUERY)
+
+    with pytest.raises(RuntimeError):
+        async for doc in scan:
+            doc
+
+
+@pytest.mark.run_loop
+async def test_non_aggregation_query(es):
+    with pytest.raises(RuntimeError):
+        CompositeAggregationScan(
+            es,
+            {
+                'query': {
+                    'bool': {
+                        'match_all': {},
+                    },
+                },
+            },
+        )
+
+
+@pytest.mark.run_loop
+async def test_non_composite_aggregation(es):
+    with pytest.raises(RuntimeError):
+        CompositeAggregationScan(
+            es,
+            {
+                'aggs': {
+                    'counts': {
+                        'value_count': {
+                            'field': 'domains',
+                        },
+                    },
+                },
+            },
+        )
+
+
+@pytest.mark.run_loop
+async def test_scan(es, populate_aggs_data):
+    await populate_aggs_data(INDEX, ES_DATA)
+
+    async with CompositeAggregationScan(
+        es,
+        QUERY,
+        index=INDEX,
+    ) as scan:
+        i = 1
+        async for doc in scan:
+            assert doc == {
+                'key': {'score': str(i)},
+                'doc_count': i,
+            }
+
+            i += 1
+
+    assert i == 4
+
+
+@pytest.mark.run_loop
+async def test_scan_no_index(es, populate_aggs_data):
+    await populate_aggs_data(INDEX, ES_DATA)
+
+    async with CompositeAggregationScan(
+        es,
+        QUERY,
+    ) as scan:
+        i = 1
+        async for doc in scan:
+            assert doc == {
+                'key': {'score': str(i)},
+                'doc_count': i,
+            }
+
+            i += 1
+
+    assert i == 4
+
+
+@pytest.mark.run_loop
+async def test_scan_multiple_fetch(es, populate_aggs_data):
+    await populate_aggs_data(INDEX, ES_DATA)
+
+    q = deepcopy(QUERY)
+    q['aggs']['buckets']['composite']['size'] = 1
+
+    async with CompositeAggregationScan(
+        es,
+        q,
+        index=INDEX,
+    ) as scan:
+        i = 1
+        original_update = scan._update_state
+        mock_update = mock.MagicMock(side_effect=original_update)
+        with mock.patch.object(scan, '_update_state', mock_update):
+
+            async for doc in scan:
+                assert doc == {
+                    'key': {'score': str(i)},
+                    'doc_count': i,
+                }
+
+                i += 1
+
+    assert mock_update.call_count == 3
+
+
+@pytest.mark.run_loop
+async def test_scan_with_prefetch_next(es, populate_aggs_data):
+    await populate_aggs_data(INDEX, ES_DATA)
+
+    q = deepcopy(QUERY)
+    q['aggs']['buckets']['composite']['size'] = 1
+
+    async with CompositeAggregationScan(
+        es,
+        q,
+        prefetch_next_chunk=True,
+        index=INDEX,
+    ) as scan:
+        original_reset_prefetch = scan._reset_prefetched
+        mock_reset_prefetch = mock.MagicMock(
+            side_effect=original_reset_prefetch,
+        )
+        with mock.patch.object(
+            scan,
+            '_reset_prefetched',
+            mock_reset_prefetch,
+        ):
+            async for _ in scan:  # noqa
+                assert scan._prefetched is not None
+
+            assert mock_reset_prefetch.call_count == 3
+
+        last_prefetch_task = scan._prefetched
+
+    await asyncio.sleep(0)
+
+    assert last_prefetch_task.cancelled()
+    assert scan._prefetched is None
+
+
+@pytest.mark.run_loop
+async def test_scan_warning_on_failed_shards(
+    es,
+    populate_aggs_data,
+    mocker,
+):
+    mocker.spy(logger, 'warning')
+
+    await populate_aggs_data(INDEX, ES_DATA)
+
+    async with CompositeAggregationScan(
+        es,
+        QUERY,
+        raise_on_error=False,
+        index=INDEX,
+    ) as scan:
+        i = 0
+        async for doc in scan:  # noqa
+            if i == 1:
+                scan._successful_shards = 4
+                scan._total_shards = 5
+            i += 1
+
+    logger.warning.assert_called_once_with(
+        'Aggregation request has only succeeded on %d shards out of %d.',
+        4,
+        5,
+    )
+
+
+@pytest.mark.run_loop
+async def test_scan_exception_on_failed_shards(
+    es,
+    populate_aggs_data,
+    mocker,
+):
+    mocker.spy(logger, 'warning')
+
+    await populate_aggs_data(INDEX, ES_DATA)
+
+    async with CompositeAggregationScan(
+        es,
+        QUERY,
+        index=INDEX,
+    ) as scan:
+        i = 0
+        with pytest.raises(ElasticsearchException):
+            async for doc in scan:  # noqa
+                if i == 1:
+                    scan._successful_shards = 4
+                    scan._total_shards = 5
+                i += 1
+
+    assert i == 3
+    logger.warning.assert_called_once_with(
+        'Aggregation request has only succeeded on %d shards out of %d.',
+        4,
+        5,
+    )
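For context, ``CompositeAggregationScan`` automates the hand-rolled ``after``-pagination loop one would otherwise write against the raw client. A rough equivalent, assuming the same query and index as the tests above:

.. code-block:: python

    import asyncio

    from aioelasticsearch import Elasticsearch

    QUERY = {
        'aggs': {
            'buckets': {
                'composite': {
                    'sources': [
                        {'score': {'terms': {'field': 'score.keyword'}}},
                    ],
                },
            },
        },
    }


    async def manual_scan():
        async with Elasticsearch() as es:
            while True:
                resp = await es.search(index='test_aioes', body=QUERY)
                aggs = resp['aggregations']['buckets']

                for bucket in aggs['buckets']:
                    print(bucket['key'], bucket['doc_count'])

                after_key = aggs.get('after_key')
                if not aggs['buckets'] or after_key is None:
                    break  # same stop condition as _update_state()

                # feed the cursor back into the query, as _update_query() does
                QUERY['aggs']['buckets']['composite']['after'] = after_key

    asyncio.get_event_loop().run_until_complete(manual_scan())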
From d9f3f3c2a9dbb44fe715fb93b4d19e46ee6bbc94 Mon Sep 17 00:00:00 2001
From: Petro Kostiukevych
Date: Wed, 15 Jul 2020 13:31:09 +0300
Subject: [PATCH 2/3] comply with python version < 3.7; add README example

---
 README.rst                   | 37 ++++++++++++++++++++++++++++++++++++
 aioelasticsearch/helpers.py  |  8 +++++++-
 tests/test_composite_scan.py | 16 ++++++++++++----
 3 files changed, 56 insertions(+), 5 deletions(-)

diff --git a/README.rst b/README.rst
index c1564528..74b63827 100644
--- a/README.rst
+++ b/README.rst
@@ -68,6 +68,43 @@ Asynchronous `scroll `_.
+
+.. code-block:: python
+
+    import asyncio
+
+    from aioelasticsearch import Elasticsearch
+    from aioelasticsearch.helpers import CompositeAggregationScan
+
+    QUERY = {
+        'aggs': {
+            'buckets': {
+                'composite': {
+                    'sources': [
+                        {'score': {'terms': {'field': 'score.keyword'}}},
+                    ],
+                    'size': 5,
+                },
+            },
+        },
+    }
+
+    async def go():
+        async with Elasticsearch() as es:
+            async with CompositeAggregationScan(
+                es,
+                QUERY,
+                index='index',
+            ) as scan:
+
+                async for doc in scan:
+                    print(doc['doc_count'], doc['key'])
+
+    loop = asyncio.get_event_loop()
+    loop.run_until_complete(go())
+    loop.close()
+
 Thanks
 ------
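The README example sticks to the defaults; the constructor also accepts the flags exercised by the tests. A sketch combining them (``my-index`` is a placeholder, not from this patch):

.. code-block:: python

    from aioelasticsearch.helpers import CompositeAggregationScan

    async def collect_buckets(es, query):
        buckets = []
        async with CompositeAggregationScan(
            es,
            query,
            raise_on_error=False,      # log partial shard failures instead of raising
            prefetch_next_chunk=True,  # overlap the next search() with consumption
            index='my-index',          # placeholder index name
        ) as scan:
            async for bucket in scan:
                buckets.append(bucket)

        return buckets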
diff --git a/aioelasticsearch/helpers.py b/aioelasticsearch/helpers.py
index 12b86cf9..01f45c87 100644
--- a/aioelasticsearch/helpers.py
+++ b/aioelasticsearch/helpers.py
@@ -150,6 +150,7 @@ def __init__(
         self,
         es,
         query,
+        loop=None,
         raise_on_error=True,
         prefetch_next_chunk=False,
         **kwargs,
@@ -160,6 +161,11 @@ def __init__(
         self._prefetch_next_chunk = prefetch_next_chunk
         self._kwargs = kwargs
 
+        if loop is None:
+            loop = asyncio.get_event_loop()
+
+        self._loop = loop
+
         self._aggs_key = self._extract_aggs_key()
 
         if 'composite' not in self._query['aggs'][self._aggs_key]:
@@ -265,7 +271,7 @@ async def _fetch_results(self):
 
         self._update_state(resp)
 
         if self._prefetch_next_chunk:
-            self._prefetched = asyncio.create_task(
+            self._prefetched = self._loop.create_task(
                 self._search(),
             )

diff --git a/tests/test_composite_scan.py b/tests/test_composite_scan.py
index bd7889f5..a1c33dda 100644
--- a/tests/test_composite_scan.py
+++ b/tests/test_composite_scan.py
@@ -113,12 +113,13 @@ async def test_non_composite_aggregation(es):
 
 
 @pytest.mark.run_loop
-async def test_scan(es, populate_aggs_data):
+async def test_scan(loop, es, populate_aggs_data):
     await populate_aggs_data(INDEX, ES_DATA)
 
     async with CompositeAggregationScan(
         es,
         QUERY,
+        loop=loop,
         index=INDEX,
     ) as scan:
         i = 1
@@ -134,12 +135,13 @@ async def test_scan(es, populate_aggs_data):
 
 
 @pytest.mark.run_loop
-async def test_scan_no_index(es, populate_aggs_data):
+async def test_scan_no_index(loop, es, populate_aggs_data):
     await populate_aggs_data(INDEX, ES_DATA)
 
     async with CompositeAggregationScan(
         es,
         QUERY,
+        loop=loop,
     ) as scan:
         i = 1
         async for doc in scan:
@@ -154,7 +156,7 @@ async def test_scan_no_index(es, populate_aggs_data):
 
 
 @pytest.mark.run_loop
-async def test_scan_multiple_fetch(es, populate_aggs_data):
+async def test_scan_multiple_fetch(loop, es, populate_aggs_data):
     await populate_aggs_data(INDEX, ES_DATA)
 
     q = deepcopy(QUERY)
@@ -163,6 +165,7 @@ async def test_scan_multiple_fetch(es, populate_aggs_data):
     async with CompositeAggregationScan(
         es,
         q,
+        loop=loop,
         index=INDEX,
     ) as scan:
         i = 1
@@ -182,7 +185,7 @@ async def test_scan_multiple_fetch(es, populate_aggs_data):
 
 
 @pytest.mark.run_loop
-async def test_scan_with_prefetch_next(es, populate_aggs_data):
+async def test_scan_with_prefetch_next(loop, es, populate_aggs_data):
     await populate_aggs_data(INDEX, ES_DATA)
 
     q = deepcopy(QUERY)
@@ -191,6 +194,7 @@ async def test_scan_with_prefetch_next(es, populate_aggs_data):
     async with CompositeAggregationScan(
         es,
         q,
+        loop=loop,
         prefetch_next_chunk=True,
         index=INDEX,
     ) as scan:
@@ -218,6 +222,7 @@ async def test_scan_with_prefetch_next(es, populate_aggs_data):
 
 @pytest.mark.run_loop
 async def test_scan_warning_on_failed_shards(
+    loop,
     es,
     populate_aggs_data,
     mocker,
@@ -229,6 +234,7 @@ async def test_scan_warning_on_failed_shards(
     async with CompositeAggregationScan(
         es,
         QUERY,
+        loop=loop,
         raise_on_error=False,
         index=INDEX,
     ) as scan:
@@ -248,6 +254,7 @@ async def test_scan_warning_on_failed_shards(
 
 @pytest.mark.run_loop
 async def test_scan_exception_on_failed_shards(
+    loop,
     es,
     populate_aggs_data,
     mocker,
@@ -259,6 +266,7 @@ async def test_scan_exception_on_failed_shards(
     async with CompositeAggregationScan(
         es,
         QUERY,
+        loop=loop,
         index=INDEX,
     ) as scan:
         i = 0
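``asyncio.create_task()`` only exists on Python 3.7+, hence the switch to ``loop.create_task()`` above; on older interpreters the loop can be passed in explicitly. A sketch of that usage (the index name is a placeholder):

.. code-block:: python

    import asyncio

    from aioelasticsearch import Elasticsearch
    from aioelasticsearch.helpers import CompositeAggregationScan

    QUERY = {
        'aggs': {
            'buckets': {
                'composite': {
                    'sources': [
                        {'score': {'terms': {'field': 'score.keyword'}}},
                    ],
                },
            },
        },
    }

    loop = asyncio.get_event_loop()

    async def go():
        async with Elasticsearch() as es:
            async with CompositeAggregationScan(
                es,
                QUERY,
                loop=loop,                 # the loop that owns the prefetch task
                prefetch_next_chunk=True,
                index='test_aioes',        # placeholder index name
            ) as scan:
                async for bucket in scan:
                    print(bucket)

    loop.run_until_complete(go())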
From 05c358b92a9c181baccb4b2dee21a4c60892a7d7 Mon Sep 17 00:00:00 2001
From: Petro Kostiukevych
Date: Wed, 15 Jul 2020 13:35:26 +0300
Subject: [PATCH 3/3] comply with python 3.5

---
 aioelasticsearch/helpers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aioelasticsearch/helpers.py b/aioelasticsearch/helpers.py
index 01f45c87..9f4041a7 100644
--- a/aioelasticsearch/helpers.py
+++ b/aioelasticsearch/helpers.py
@@ -153,7 +153,7 @@ def __init__(
         loop=None,
         raise_on_error=True,
         prefetch_next_chunk=False,
-        **kwargs,
+        **kwargs
     ):
         self._es = es
         self._query = deepcopy(query)
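End to end, with the default ``raise_on_error=True`` a partially failed aggregation surfaces as ``ElasticsearchException``. A sketch of handling it (the index name is a placeholder):

.. code-block:: python

    import asyncio

    from aioelasticsearch import Elasticsearch, ElasticsearchException
    from aioelasticsearch.helpers import CompositeAggregationScan

    QUERY = {
        'aggs': {
            'buckets': {
                'composite': {
                    'sources': [
                        {'score': {'terms': {'field': 'score.keyword'}}},
                    ],
                },
            },
        },
    }

    async def go():
        async with Elasticsearch() as es:
            try:
                async with CompositeAggregationScan(
                    es,
                    QUERY,
                    index='my-index',  # placeholder index name
                ) as scan:
                    async for bucket in scan:
                        print(bucket['key'], bucket['doc_count'])
            except ElasticsearchException:
                # raised when only some shards succeeded
                print('aggregation partially failed')

    asyncio.get_event_loop().run_until_complete(go())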