Skip to content

Commit f38568e

Browse files
committed
add draft messages
1 parent 564e947 commit f38568e

File tree

9 files changed

+448
-3
lines changed

9 files changed

+448
-3
lines changed

stream_chat/async_chat/channel.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Any, Dict, Iterable, List, Union
2+
from typing import Any, Dict, Iterable, List, Optional, Union
33

44
from stream_chat.base.channel import ChannelInterface, add_user_id
55
from stream_chat.base.exceptions import StreamChannelException
@@ -247,3 +247,23 @@ async def update_member_partial(
247247

248248
payload = {"set": to_set or {}, "unset": to_unset or []}
249249
return await self.client.patch(f"{self.url}/member/{user_id}", data=payload)
250+
251+
async def create_draft(self, message: Dict, user_id: str) -> StreamResponse:
252+
payload = {"message": add_user_id(message, user_id)}
253+
return await self.client.post(f"{self.url}/draft", data=payload)
254+
255+
async def delete_draft(
256+
self, user_id: str, parent_id: Optional[str] = None
257+
) -> StreamResponse:
258+
params = {"user_id": user_id}
259+
if parent_id:
260+
params["parent_id"] = parent_id
261+
return await self.client.delete(f"{self.url}/draft", params=params)
262+
263+
async def get_draft(
264+
self, user_id: str, parent_id: Optional[str] = None
265+
) -> StreamResponse:
266+
params = {"user_id": user_id}
267+
if parent_id:
268+
params["parent_id"] = parent_id
269+
return await self.client.get(f"{self.url}/draft", params=params)

stream_chat/async_chat/client.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from stream_chat.async_chat.segment import Segment
2222
from stream_chat.types.base import SortParam
2323
from stream_chat.types.campaign import CampaignData, QueryCampaignsOptions
24+
from stream_chat.types.draft import QueryDraftsFilter, QueryDraftsOptions
2425
from stream_chat.types.segment import (
2526
QuerySegmentsOptions,
2627
QuerySegmentTargetsOptions,
@@ -825,6 +826,28 @@ async def unread_counts(self, user_id: str) -> StreamResponse:
825826
async def unread_counts_batch(self, user_ids: List[str]) -> StreamResponse:
826827
return await self.post("unread_batch", data={"user_ids": user_ids})
827828

829+
async def query_drafts(
830+
self,
831+
user_id: str,
832+
filter: Optional[QueryDraftsFilter] = None,
833+
sort: Optional[List[SortParam]] = None,
834+
options: Optional[QueryDraftsOptions] = None,
835+
) -> StreamResponse:
836+
data: Dict[str, Union[str, Dict[str, Any], List[SortParam]]] = {
837+
"user_id": user_id
838+
}
839+
840+
if filter is not None:
841+
data["filter"] = cast(dict, filter)
842+
843+
if sort is not None:
844+
data["sort"] = cast(dict, sort)
845+
846+
if options is not None:
847+
data.update(cast(dict, options))
848+
849+
return await self.post("drafts/query", data=data)
850+
828851
async def close(self) -> None:
829852
await self.session.close()
830853

stream_chat/base/channel.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import abc
2-
from typing import Any, Awaitable, Dict, Iterable, List, Union
2+
from typing import Any, Awaitable, Dict, Iterable, List, Optional, Union
33

44
from stream_chat.base.client import StreamChatInterface
55
from stream_chat.base.exceptions import StreamChannelException
@@ -488,6 +488,45 @@ def update_member_partial(
488488
"""
489489
pass
490490

491+
@abc.abstractmethod
492+
def create_draft(
493+
self, message: Dict, user_id: str
494+
) -> Union[StreamResponse, Awaitable[StreamResponse]]:
495+
"""
496+
Creates or updates a draft message in a channel.
497+
498+
:param message: The message object
499+
:param user_id: The ID of the user creating the draft
500+
:return: The Server Response
501+
"""
502+
pass
503+
504+
@abc.abstractmethod
505+
def delete_draft(
506+
self, user_id: str, parent_id: Optional[str] = None
507+
) -> Union[StreamResponse, Awaitable[StreamResponse]]:
508+
"""
509+
Deletes a draft message from a channel.
510+
511+
:param user_id: The ID of the user who owns the draft
512+
:param parent_id: Optional ID of the parent message if this is a thread draft
513+
:return: The Server Response
514+
"""
515+
pass
516+
517+
@abc.abstractmethod
518+
def get_draft(
519+
self, user_id: str, parent_id: Optional[str] = None
520+
) -> Union[StreamResponse, Awaitable[StreamResponse]]:
521+
"""
522+
Retrieves a draft message from a channel.
523+
524+
:param user_id: The ID of the user who owns the draft
525+
:param parent_id: Optional ID of the parent message if this is a thread draft
526+
:return: The Server Response
527+
"""
528+
pass
529+
491530

492531
def add_user_id(payload: Dict, user_id: str) -> Dict:
493532
return {**payload, "user": {"id": user_id}}

stream_chat/base/client.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from stream_chat.types.base import SortParam
1111
from stream_chat.types.campaign import CampaignData, QueryCampaignsOptions
12+
from stream_chat.types.draft import QueryDraftsFilter, QueryDraftsOptions
1213
from stream_chat.types.segment import (
1314
QuerySegmentsOptions,
1415
QuerySegmentTargetsOptions,
@@ -1384,6 +1385,16 @@ def unread_counts_batch(
13841385
"""
13851386
pass
13861387

1388+
@abc.abstractmethod
1389+
def query_drafts(
1390+
self,
1391+
user_id: str,
1392+
filter: Optional[QueryDraftsFilter] = None,
1393+
sort: Optional[List[SortParam]] = None,
1394+
options: Optional[QueryDraftsOptions] = None,
1395+
) -> Union[StreamResponse, Awaitable[StreamResponse]]:
1396+
pass
1397+
13871398
#####################
13881399
# Private methods #
13891400
#####################

stream_chat/channel.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Any, Dict, Iterable, List, Union
2+
from typing import Any, Dict, Iterable, List, Optional, Union
33

44
from stream_chat.base.channel import ChannelInterface, add_user_id
55
from stream_chat.base.exceptions import StreamChannelException
@@ -248,3 +248,26 @@ def update_member_partial(
248248

249249
payload = {"set": to_set or {}, "unset": to_unset or []}
250250
return self.client.patch(f"{self.url}/member/{user_id}", data=payload)
251+
252+
def create_draft(self, message: Dict, user_id: str) -> StreamResponse:
253+
message["user_id"] = user_id
254+
payload = {"message": message}
255+
return self.client.post(f"{self.url}/draft", data=payload)
256+
257+
def delete_draft(
258+
self, user_id: str, parent_id: Optional[str] = None
259+
) -> StreamResponse:
260+
params = {"user_id": user_id}
261+
if parent_id:
262+
params["parent_id"] = parent_id
263+
264+
return self.client.delete(f"{self.url}/draft", params=params)
265+
266+
def get_draft(
267+
self, user_id: str, parent_id: Optional[str] = None
268+
) -> StreamResponse:
269+
params = {"user_id": user_id}
270+
if parent_id:
271+
params["parent_id"] = parent_id
272+
273+
return self.client.get(f"{self.url}/draft", params=params)

stream_chat/client.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from stream_chat.segment import Segment
1111
from stream_chat.types.base import SortParam
1212
from stream_chat.types.campaign import CampaignData, QueryCampaignsOptions
13+
from stream_chat.types.draft import QueryDraftsFilter, QueryDraftsOptions
1314
from stream_chat.types.segment import (
1415
QuerySegmentsOptions,
1516
QuerySegmentTargetsOptions,
@@ -782,3 +783,21 @@ def unread_counts(self, user_id: str) -> StreamResponse:
782783

783784
def unread_counts_batch(self, user_ids: List[str]) -> StreamResponse:
784785
return self.post("unread_batch", data={"user_ids": user_ids})
786+
787+
def query_drafts(
788+
self,
789+
user_id: str,
790+
filter: Optional[QueryDraftsFilter] = None,
791+
sort: Optional[List[SortParam]] = None,
792+
options: Optional[QueryDraftsOptions] = None,
793+
) -> StreamResponse:
794+
data: Dict[str, Union[str, Dict[str, Any], List[SortParam]]] = {
795+
"user_id": user_id
796+
}
797+
if filter is not None:
798+
data["filter"] = cast(dict, filter)
799+
if sort is not None:
800+
data["sort"] = cast(dict, sort)
801+
if options is not None:
802+
data.update(cast(dict, options))
803+
return self.post("drafts/query", data=data)
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import uuid
2+
from typing import Dict
3+
4+
import pytest
5+
6+
from stream_chat.async_chat.channel import Channel
7+
from stream_chat.async_chat.client import StreamChatAsync
8+
from stream_chat.types.base import SortOrder
9+
10+
11+
@pytest.mark.incremental
12+
class TestDraft:
13+
async def test_create_draft(self, channel: Channel, random_user: Dict):
14+
draft_message = {"text": "This is a draft message"}
15+
response = await channel.create_draft(draft_message, random_user["id"])
16+
17+
assert "draft" in response
18+
assert response["draft"]["message"]["text"] == "This is a draft message"
19+
assert response["draft"]["channel_cid"] == channel.cid
20+
21+
async def test_get_draft(self, channel: Channel, random_user: Dict):
22+
# First create a draft
23+
draft_message = {"text": "This is a draft to retrieve"}
24+
await channel.create_draft(draft_message, random_user["id"])
25+
26+
# Then get the draft
27+
response = await channel.get_draft(random_user["id"])
28+
29+
assert "draft" in response
30+
assert response["draft"]["message"]["text"] == "This is a draft to retrieve"
31+
assert response["draft"]["channel_cid"] == channel.cid
32+
33+
async def test_delete_draft(self, channel: Channel, random_user: Dict):
34+
# First create a draft
35+
draft_message = {"text": "This is a draft to delete"}
36+
await channel.create_draft(draft_message, random_user["id"])
37+
38+
# Then delete the draft
39+
await channel.delete_draft(random_user["id"])
40+
41+
# Verify it's deleted by trying to get it
42+
try:
43+
await channel.get_draft(random_user["id"])
44+
raise AssertionError("Draft should be deleted")
45+
except Exception:
46+
# Expected behavior, draft should not be found
47+
pass
48+
49+
async def test_thread_draft(self, channel: Channel, random_user: Dict):
50+
# First create a parent message
51+
msg = await channel.send_message({"text": "Parent message"}, random_user["id"])
52+
parent_id = msg["message"]["id"]
53+
54+
# Create a draft reply
55+
draft_reply = {"text": "This is a draft reply", "parent_id": parent_id}
56+
response = await channel.create_draft(draft_reply, random_user["id"])
57+
58+
assert "draft" in response
59+
assert response["draft"]["message"]["text"] == "This is a draft reply"
60+
assert response["draft"]["parent_id"] == parent_id
61+
62+
# Get the draft reply
63+
response = await channel.get_draft(random_user["id"], parent_id=parent_id)
64+
65+
assert "draft" in response
66+
assert response["draft"]["message"]["text"] == "This is a draft reply"
67+
assert response["draft"]["parent_id"] == parent_id
68+
69+
# Delete the draft reply
70+
await channel.delete_draft(random_user["id"], parent_id=parent_id)
71+
72+
# Verify it's deleted
73+
try:
74+
await channel.get_draft(random_user["id"], parent_id=parent_id)
75+
raise AssertionError("Thread draft should be deleted")
76+
except Exception:
77+
# Expected behavior
78+
pass
79+
80+
async def test_query_drafts(
81+
self, client: StreamChatAsync, channel: Channel, random_user: Dict
82+
):
83+
# Create multiple drafts in different channels
84+
draft1 = {"text": "Draft in channel 1"}
85+
await channel.create_draft(draft1, random_user["id"])
86+
87+
# Create another channel with a draft
88+
channel2 = client.channel("messaging", str(uuid.uuid4()))
89+
await channel2.create(random_user["id"])
90+
91+
draft2 = {"text": "Draft in channel 2"}
92+
await channel2.create_draft(draft2, random_user["id"])
93+
94+
# Query all drafts for the user
95+
response = await client.query_drafts(random_user["id"])
96+
97+
assert "drafts" in response
98+
assert len(response["drafts"]) == 2
99+
100+
# Query drafts for a specific channel
101+
response = await client.query_drafts(
102+
random_user["id"], filter={"channel_cid": channel2.cid}
103+
)
104+
105+
assert "drafts" in response
106+
assert len(response["drafts"]) == 1
107+
draft = response["drafts"][0]
108+
assert draft["channel_cid"] == channel2.cid
109+
assert draft["message"]["text"] == "Draft in channel 2"
110+
111+
# Query drafts with sort
112+
response = await client.query_drafts(
113+
random_user["id"],
114+
sort=[{"field": "created_at", "direction": SortOrder.ASC}],
115+
)
116+
117+
assert "drafts" in response
118+
assert len(response["drafts"]) == 2
119+
assert response["drafts"][0]["channel_cid"] == channel.cid
120+
assert response["drafts"][1]["channel_cid"] == channel2.cid
121+
122+
# Query drafts with pagination
123+
response = await client.query_drafts(
124+
random_user["id"],
125+
options={"limit": 1},
126+
)
127+
128+
assert "drafts" in response
129+
assert len(response["drafts"]) == 1
130+
assert response["drafts"][0]["channel_cid"] == channel2.cid
131+
132+
assert response["next"] is not None
133+
134+
# Query drafts with pagination
135+
response = await client.query_drafts(
136+
random_user["id"],
137+
options={"limit": 1, "next": response["next"]},
138+
)
139+
140+
assert "drafts" in response
141+
assert len(response["drafts"]) == 1
142+
assert response["drafts"][0]["channel_cid"] == channel.cid
143+
144+
# Cleanup
145+
try:
146+
await channel2.delete()
147+
except Exception:
148+
pass

0 commit comments

Comments
 (0)