Skip to content

Commit 1aa719a

Browse files
shortcutstkrugg
andauthored
fix(iterators): prevent recursion issue (#542)
Co-authored-by: Youcef Mammar <[email protected]>
1 parent 8db4220 commit 1aa719a

File tree

3 files changed

+50
-3
lines changed

3 files changed

+50
-3
lines changed

algoliasearch/iterators.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@ def __iter__(self):
3535

3636

3737
class PaginatorIterator(Iterator):
38+
nbHits = 0
39+
3840
def __init__(self, transporter, index_name, request_options=None):
3941
# type: (Transporter, str, Optional[Union[dict, RequestOptions]]) -> None # noqa: E501
4042

@@ -48,16 +50,16 @@ def __init__(self, transporter, index_name, request_options=None):
4850

4951
def __next__(self):
5052
# type: () -> dict
51-
5253
if self._raw_response:
54+
5355
if len(self._raw_response["hits"]):
5456
hit = self._raw_response["hits"].pop(0)
5557

5658
hit.pop("_highlightResult")
5759

5860
return hit
5961

60-
if self._raw_response["nbHits"] < self._data["hitsPerPage"]:
62+
if self.nbHits < self._data["hitsPerPage"]:
6163
self._raw_response = {}
6264
self._data = {
6365
"hitsPerPage": 1000,
@@ -68,6 +70,7 @@ def __next__(self):
6870
self._raw_response = self._transporter.read(
6971
Verb.POST, self.get_endpoint(), self._data, self._request_options
7072
)
73+
self.nbHits = len(self._raw_response["hits"])
7174

7275
self._data["page"] += 1
7376

algoliasearch/iterators_async.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111

1212
class PaginatorIteratorAsync(Iterator):
13+
nbHits = 0
14+
1315
def __init__(self, transporter, index_name, request_options=None):
1416
# type: (Transporter, str, Optional[Union[dict, RequestOptions]]) -> None # noqa: E501
1517

@@ -38,7 +40,7 @@ def __anext__(self): # type: ignore
3840

3941
return hit
4042

41-
if self._raw_response["nbHits"] < self._data["hitsPerPage"]:
43+
if self.nbHits < self._data["hitsPerPage"]:
4244
self._raw_response = {}
4345
self._data = {
4446
"hitsPerPage": 1000,
@@ -49,6 +51,7 @@ def __anext__(self): # type: ignore
4951
self._raw_response = yield from self._transporter.read(
5052
Verb.POST, self.get_endpoint(), self._data, self._request_options
5153
)
54+
self.nbHits = len(self._raw_response["hits"])
5255

5356
self._data["page"] += 1
5457

tests/features/test_search_index.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
# -*- coding: utf-8 -*-
22
import sys
33
import unittest
4+
import json
5+
import requests
6+
7+
from requests.models import Response
48

59
from algoliasearch.exceptions import RequestException, ObjectNotFoundException
610
from algoliasearch.responses import MultipleResponse
11+
from algoliasearch.search_client import SearchClient
712
from algoliasearch.search_index import SearchIndex
813
from tests.helpers.factory import Factory as F
914
from tests.helpers.misc import Unicode, rule_without_metadata
15+
from unittest.mock import MagicMock
1016

1117

1218
class TestSearchIndex(unittest.TestCase):
@@ -450,6 +456,41 @@ def test_synonyms(self):
450456
# and check that the number of returned synonyms is equal to 0
451457
self.assertEqual(self.index.search_synonyms("")["nbHits"], 0)
452458

459+
def test_browse_rules(self):
460+
def side_effect(req, **kwargs):
461+
hits = [{"objectID": i, "_highlightResult": None} for i in range(0, 1000)]
462+
page = json.loads(req.body)["page"]
463+
464+
if page == 3:
465+
hits = hits[0:800]
466+
467+
response = Response()
468+
response.status_code = 200
469+
response._content = str.encode(
470+
json.dumps(
471+
{
472+
"hits": hits,
473+
"nbHits": 3800,
474+
"page": page,
475+
"nbPages": 3,
476+
}
477+
)
478+
)
479+
480+
return response
481+
482+
client = SearchClient.create("foo", "bar")
483+
client._transporter._requester._session = requests.Session()
484+
client._transporter._requester._session.send = MagicMock(name="send")
485+
client._transporter._requester._session.send.side_effect = side_effect
486+
index = F.index(client, "test")
487+
488+
rules = index.browse_rules()
489+
490+
len_rules = len(list(rules))
491+
492+
self.assertEqual(len_rules, 3800)
493+
453494
def test_rules(self):
454495
responses = MultipleResponse()
455496

0 commit comments

Comments
 (0)