Skip to content

Commit 5099bbf

Browse files
[CHAT-1319] Added ability to provide sort options as list to guarantee field order (#42)
Co-authored-by: Guyon Morée <[email protected]>
1 parent 671abfb commit 5099bbf

File tree

4 files changed

+49
-16
lines changed

4 files changed

+49
-16
lines changed

stream_chat/channel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def query_members(self, filter_conditions, sort=None, **options):
107107
108108
eg.
109109
channel.query_members(filter_conditions={"name": "tommaso"},
110-
sort=[{"field": "created_at", "direction": -1}],
110+
sort=[{"created_at": -1}, {"updated_at": 1}],
111111
offset=0,
112112
limit=10)
113113
"""
@@ -116,7 +116,7 @@ def query_members(self, filter_conditions, sort=None, **options):
116116
"id": self.id,
117117
"type": self.channel_type,
118118
"filter_conditions": filter_conditions,
119-
"sort": sort or [],
119+
"sort": self.client.normalize_sort(sort),
120120
**options,
121121
}
122122
response = self.client.get("members", params={"payload": json.dumps(payload)})

stream_chat/client.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import collections
12
from urllib.parse import urlparse
23
import hmac
34
import hashlib
@@ -75,6 +76,19 @@ def _make_request(self, method, relative_url, params=None, data=None):
7576
)
7677
return self._parse_response(response)
7778

79+
def normalize_sort(self, sort=None):
80+
sort_fields = []
81+
if isinstance(sort, collections.abc.Mapping):
82+
sort = [sort]
83+
if isinstance(sort, list):
84+
for item in sort:
85+
if "field" in item and "direction" in item:
86+
sort_fields.append(item)
87+
else:
88+
for k, v in item.items():
89+
sort_fields.append({"field": k, "direction": v})
90+
return sort_fields
91+
7892
def put(self, relative_url, params=None, data=None):
7993
return self._make_request(self.session.put, relative_url, params, data)
8094

@@ -189,20 +203,18 @@ def get_message(self, message_id):
189203
return self.get("messages/{}".format(message_id))
190204

191205
def query_users(self, filter_conditions, sort=None, **options):
192-
sort_fields = []
193-
if sort is not None:
194-
sort_fields = [{"field": k, "direction": v} for k, v in sort.items()]
195206
params = options.copy()
196-
params.update({"filter_conditions": filter_conditions, "sort": sort_fields})
207+
params.update(
208+
{"filter_conditions": filter_conditions, "sort": self.normalize_sort(sort)}
209+
)
197210
return self.get("users", params={"payload": json.dumps(params)})
198211

199212
def query_channels(self, filter_conditions, sort=None, **options):
200213
params = {"state": True, "watch": False, "presence": False}
201-
sort_fields = []
202-
if sort is not None:
203-
sort_fields = [{"field": k, "direction": v} for k, v in sort.items()]
204214
params.update(options)
205-
params.update({"filter_conditions": filter_conditions, "sort": sort_fields})
215+
params.update(
216+
{"filter_conditions": filter_conditions, "sort": self.normalize_sort(sort)}
217+
)
206218
return self.get("channels", params={"payload": json.dumps(params)})
207219

208220
def create_channel_type(self, data):

stream_chat/tests/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ def command(client):
7676
dict(name=str(uuid.uuid4()), description="My command")
7777
)
7878

79-
return response["command"]
79+
yield response["command"]
80+
81+
client.delete_command(response["command"]["name"])
8082

8183

8284
@pytest.fixture(scope="module")

stream_chat/tests/test_client.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,36 @@
1+
from operator import itemgetter
2+
13
import jwt
24
import pytest
5+
import sys
36
import uuid
47
from stream_chat import StreamChat
58
from stream_chat.exceptions import StreamAPIException
69

710

811
@pytest.mark.incremental
912
class TestClient(object):
13+
def test_normalize_sort(self, client):
14+
expected = [
15+
{"field": "field1", "direction": 1},
16+
{"field": "field2", "direction": -1},
17+
]
18+
actual = client.normalize_sort([{"field1": 1}, {"field2": -1}])
19+
assert actual == expected
20+
actual = client.normalize_sort(
21+
[{"field": "field1", "direction": 1}, {"field": "field2", "direction": -1}]
22+
)
23+
assert actual == expected
24+
actual = client.normalize_sort({"field1": 1})
25+
assert actual == [{"field": "field1", "direction": 1}]
26+
# The following example is not recommended because the order of the fields is not guaranteed in Python < 3.7
27+
actual = client.normalize_sort({"field1": 1, "field2": -1})
28+
if sys.version_info >= (3, 7):
29+
assert actual == expected
30+
else:
31+
# Compare elements regardless of the order
32+
assert sorted(actual, key=itemgetter("field")) == expected
33+
1034
def test_mute_user(self, client, random_users):
1135
response = client.mute_user(random_users[0]["id"], random_users[1]["id"])
1236
assert "mute" in response
@@ -61,11 +85,6 @@ def test_update_command(self, client, command):
6185
assert "command" in response
6286
assert "My new command" == response["command"]["description"]
6387

64-
def test_delete_command(self, client, command):
65-
response = client.delete_command(command["name"])
66-
with pytest.raises(StreamAPIException):
67-
client.get_command(command["name"])
68-
6988
def test_list_commands(self, client):
7089
response = client.list_commands()
7190
assert "commands" in response

0 commit comments

Comments
 (0)