Skip to content

Commit 896f6e2

Browse files
authored
Feature: Search improvements (#63)
1 parent 6178047 commit 896f6e2

File tree

5 files changed

+152
-7
lines changed

5 files changed

+152
-7
lines changed

stream_chat/async_chat/client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,11 @@ async def get_rate_limits(
309309

310310
return await self.get("rate_limits", params)
311311

312-
async def search(self, filter_conditions, query, **options):
313-
params = {**options, "filter_conditions": filter_conditions, "query": query}
312+
async def search(self, filter_conditions, query, sort=None, **options):
313+
if "offset" in options:
314+
if sort or "next" in options:
315+
raise ValueError("cannot use offset with sort or next parameters")
316+
params = self.create_search_params(filter_conditions, query, sort, **options)
314317
return await self.get("search", params={"payload": json.dumps(params)})
315318

316319
async def send_file(self, uri, url, name, user, content_type=None):

stream_chat/base/client.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,17 @@ def create_token(self, user_id, exp=None, iat=None, **claims):
4646
payload["iat"] = iat
4747
return jwt.encode(payload, self.api_secret, algorithm="HS256")
4848

49+
def create_search_params(self, filter_conditions, query, sort, **options):
50+
params = options.copy()
51+
if isinstance(query, str):
52+
params.update({"query": query})
53+
else:
54+
params.update({"message_filter_conditions": query})
55+
params.update(
56+
{"filter_conditions": filter_conditions, "sort": self.normalize_sort(sort)}
57+
)
58+
return params
59+
4960
def verify_webhook(self, request_body, x_signature):
5061
"""
5162
Verify the signature added to a webhook event

stream_chat/client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,8 +304,11 @@ def get_rate_limits(
304304

305305
return self.get("rate_limits", params)
306306

307-
def search(self, filter_conditions, query, **options):
308-
params = {**options, "filter_conditions": filter_conditions, "query": query}
307+
def search(self, filter_conditions, query, sort=None, **options):
308+
if "offset" in options:
309+
if sort or "next" in options:
310+
raise ValueError("cannot use offset with sort or next parameters")
311+
params = self.create_search_params(filter_conditions, query, sort, **options)
309312
return self.get("search", params={"payload": json.dumps(params)})
310313

311314
def send_file(self, uri, url, name, user, content_type=None):

stream_chat/tests/async_chat/test_client.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from stream_chat.base.exceptions import StreamAPIException
1111

1212

13-
@pytest.mark.incremental
1413
class TestClient(object):
1514
def test_normalize_sort(self, client):
1615
expected = [
@@ -315,8 +314,36 @@ async def test_get_rate_limits(self, event_loop, client):
315314
> response["server_side"]["GetRateLimits"]["remaining"]
316315
)
317316

317+
@pytest.mark.xfail
318318
@pytest.mark.asyncio
319-
async def test_search(self, event_loop, client, channel, random_user):
319+
async def test_search_with_sort(self, client, channel, random_user):
320+
text = str(uuid.uuid4())
321+
ids = ["0" + text, "1" + text]
322+
await channel.send_message(
323+
{"text": text, "id": ids[0]},
324+
random_user["id"],
325+
)
326+
await channel.send_message(
327+
{"text": text, "id": ids[1]},
328+
random_user["id"],
329+
)
330+
response = await client.search(
331+
{"type": "messaging"}, text, **{"limit": 1, "sort": [{"created_at": -1}]}
332+
)
333+
# searches all channels so make sure at least one is found
334+
assert len(response["results"]) >= 1
335+
assert response["next"] is not None
336+
assert ids[1] == response["results"][0]["message"]["id"]
337+
response = await client.search(
338+
{"type": "messaging"}, text, **{"limit": 1, "next": response["next"]}
339+
)
340+
assert len(response["results"]) >= 1
341+
assert response["previous"] is not None
342+
assert response["next"] is None
343+
assert ids[0] == response["results"][0]["message"]["id"]
344+
345+
@pytest.mark.asyncio
346+
async def test_search(self, client, channel, random_user):
320347
query = "supercalifragilisticexpialidocious"
321348
await channel.send_message(
322349
{"text": f"How many syllables are there in {query}?"},
@@ -337,6 +364,45 @@ async def test_search(self, event_loop, client, channel, random_user):
337364
for message in response["results"]:
338365
assert query not in message["message"]["text"]
339366

367+
@pytest.mark.asyncio
368+
async def test_search_message_filters(self, client, channel, random_user):
369+
query = "supercalifragilisticexpialidocious"
370+
await channel.send_message(
371+
{"text": f"How many syllables are there in {query}?"},
372+
random_user["id"],
373+
)
374+
await channel.send_message(
375+
{"text": "Does 'cious' count as one or two?"}, random_user["id"]
376+
)
377+
response = await client.search(
378+
{"type": "messaging"},
379+
{"text": {"$q": query}},
380+
**{
381+
"limit": 2,
382+
"offset": 0,
383+
},
384+
)
385+
assert len(response["results"]) >= 1
386+
assert query in response["results"][0]["message"]["text"]
387+
388+
@pytest.mark.asyncio
389+
async def test_search_offset_with_sort(self, client):
390+
query = "supercalifragilisticexpialidocious"
391+
with pytest.raises(ValueError):
392+
await client.search(
393+
{"type": "messaging"},
394+
query,
395+
**{"limit": 2, "offset": 1, "sort": [{"created_at": -1}]},
396+
)
397+
398+
@pytest.mark.asyncio
399+
async def test_search_offset_with_next(self, client):
400+
query = "supercalifragilisticexpialidocious"
401+
with pytest.raises(ValueError):
402+
await client.search(
403+
{"type": "messaging"}, query, **{"limit": 2, "offset": 1, "next": query}
404+
)
405+
340406
@pytest.mark.asyncio
341407
async def test_query_channels_members_in(
342408
self, event_loop, client, fellowship_of_the_ring

stream_chat/tests/test_client.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from stream_chat.base.exceptions import StreamAPIException
1111

1212

13-
@pytest.mark.incremental
1413
class TestClient(object):
1514
def test_normalize_sort(self, client):
1615
expected = [
@@ -258,6 +257,33 @@ def test_get_rate_limits(self, client):
258257
> response["server_side"]["GetRateLimits"]["remaining"]
259258
)
260259

260+
@pytest.mark.xfail
261+
def test_search_with_sort(self, client, channel, random_user):
262+
text = str(uuid.uuid4())
263+
ids = ["0" + text, "1" + text]
264+
channel.send_message(
265+
{"text": text, "id": ids[0]},
266+
random_user["id"],
267+
)
268+
channel.send_message(
269+
{"text": text, "id": ids[1]},
270+
random_user["id"],
271+
)
272+
response = client.search(
273+
{"type": "messaging"}, text, **{"limit": 1, "sort": [{"created_at": -1}]}
274+
)
275+
# searches all channels so make sure at least one is found
276+
assert len(response["results"]) >= 1
277+
assert response["next"] is not None
278+
assert ids[1] == response["results"][0]["message"]["id"]
279+
response = client.search(
280+
{"type": "messaging"}, text, **{"limit": 1, "next": response["next"]}
281+
)
282+
assert len(response["results"]) >= 1
283+
assert response["previous"] is not None
284+
assert response["next"] is None
285+
assert ids[0] == response["results"][0]["message"]["id"]
286+
261287
def test_search(self, client, channel, random_user):
262288
query = "supercalifragilisticexpialidocious"
263289
channel.send_message(
@@ -279,6 +305,42 @@ def test_search(self, client, channel, random_user):
279305
for message in response["results"]:
280306
assert query not in message["message"]["text"]
281307

308+
def test_search_message_filters(self, client, channel, random_user):
309+
query = "supercalifragilisticexpialidocious"
310+
channel.send_message(
311+
{"text": f"How many syllables are there in {query}?"},
312+
random_user["id"],
313+
)
314+
channel.send_message(
315+
{"text": "Does 'cious' count as one or two?"}, random_user["id"]
316+
)
317+
response = client.search(
318+
{"type": "messaging"},
319+
{"text": {"$q": query}},
320+
**{
321+
"limit": 2,
322+
"offset": 0,
323+
},
324+
)
325+
assert len(response["results"]) >= 1
326+
assert query in response["results"][0]["message"]["text"]
327+
328+
def test_search_offset_with_sort(self, client):
329+
query = "supercalifragilisticexpialidocious"
330+
with pytest.raises(ValueError):
331+
client.search(
332+
{"type": "messaging"},
333+
query,
334+
**{"limit": 2, "offset": 1, "sort": [{"created_at": -1}]},
335+
)
336+
337+
def test_search_offset_with_next(self, client):
338+
query = "supercalifragilisticexpialidocious"
339+
with pytest.raises(ValueError):
340+
client.search(
341+
{"type": "messaging"}, query, **{"limit": 2, "offset": 1, "next": query}
342+
)
343+
282344
def test_query_channels_members_in(self, client, fellowship_of_the_ring):
283345
response = client.query_channels({"members": {"$in": ["gimli"]}}, {"id": 1})
284346
assert len(response["channels"]) == 1

0 commit comments

Comments
 (0)