Skip to content

Commit 6bb0e4c

Browse files
MartinCupelaarnautov-anton
authored andcommitted
feat: use Thread class instance to manage Thread component state when drafts enabled
1 parent 10251ec commit 6bb0e4c

File tree

8 files changed

+139
-26
lines changed

8 files changed

+139
-26
lines changed

src/components/Channel/Channel.tsx

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,14 +1145,42 @@ const ChannelInner = <V extends CustomTrigger = CustomTrigger>(
11451145

11461146
const openThread = (message: StreamMessage, event?: React.BaseSyntheticEvent) => {
11471147
event?.preventDefault();
1148-
setQuotedMessage((current) => {
1149-
if (current?.parent_id !== message?.parent_id) {
1150-
return undefined;
1151-
} else {
1152-
return current;
1148+
if (messageDraftsEnabled) {
1149+
let threadInstance = client.threads.threadsById[message.id];
1150+
if (threadInstance) {
1151+
dispatch({ channel, message, threadInstance, type: 'openThread' });
1152+
return;
11531153
}
1154-
});
1155-
dispatch({ channel, message, type: 'openThread' });
1154+
1155+
dispatch({
1156+
channel,
1157+
message,
1158+
threadInstance: {} as Thread<StreamChatGenerics>,
1159+
type: 'openThread',
1160+
});
1161+
1162+
client
1163+
.getThread(message.id, { reply_limit: DEFAULT_THREAD_PAGE_SIZE })
1164+
.then((t: Thread<StreamChatGenerics>) => {
1165+
t.registerSubscriptions();
1166+
dispatch({
1167+
channel,
1168+
message,
1169+
threadInstance: t,
1170+
type: 'openThread',
1171+
});
1172+
client.threads.addThread(t);
1173+
});
1174+
} else {
1175+
setQuotedMessage((current) => {
1176+
if (current?.parent_id !== message?.parent_id) {
1177+
return undefined;
1178+
} else {
1179+
return current;
1180+
}
1181+
});
1182+
dispatch({ channel, message, type: 'openThread' });
1183+
}
11561184
};
11571185

11581186
const closeThread = (event?: React.BaseSyntheticEvent) => {
@@ -1181,7 +1209,13 @@ const ChannelInner = <V extends CustomTrigger = CustomTrigger>(
11811209

11821210
const loadMoreThread = async (limit: number = DEFAULT_THREAD_PAGE_SIZE) => {
11831211
// FIXME: should prevent loading more, if state.thread.reply_count === channel.state.threads[parentID].length
1184-
if (state.threadLoadingMore || !state.thread || !state.threadHasMore) return;
1212+
if (
1213+
state.threadInstance ||
1214+
state.threadLoadingMore ||
1215+
!state.thread ||
1216+
!state.threadHasMore
1217+
)
1218+
return;
11851219

11861220
dispatch({ type: 'startLoadingThread' });
11871221
const parentId = state.thread.id;

src/components/Channel/__tests__/Channel.test.js

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -567,18 +567,65 @@ describe('Channel', () => {
567567
const { channel, chatClient } = await initClient();
568568
const threadMessage = messages[0];
569569
const hasThread = jest.fn();
570+
const hasThreadInstance = jest.fn();
571+
const mockThreadInstance = {
572+
threadInstanceMock: true,
573+
registerSubscriptions: jest.fn(),
574+
};
575+
const getThreadSpy = jest
576+
.spyOn(chatClient, 'getThread')
577+
.mockResolvedValueOnce(mockThreadInstance);
570578

571579
// this renders Channel, calls openThread from a child context consumer with a message,
572580
// and then calls hasThread with the thread id if it was set.
573-
await renderComponent({ channel, chatClient }, ({ openThread, thread }) => {
574-
if (!thread) {
575-
openThread(threadMessage, { preventDefault: () => null });
576-
} else {
577-
hasThread(thread.id);
578-
}
581+
await renderComponent(
582+
{ channel, chatClient },
583+
({ openThread, thread, threadInstance }) => {
584+
if (!thread) {
585+
openThread(threadMessage, { preventDefault: () => null });
586+
} else {
587+
hasThread(thread.id);
588+
hasThreadInstance(threadInstance);
589+
}
590+
},
591+
);
592+
593+
await waitFor(() => {
594+
expect(hasThread).toHaveBeenCalledWith(threadMessage.id);
595+
expect(getThreadSpy).not.toHaveBeenCalled();
596+
expect(hasThreadInstance).toHaveBeenCalledWith(undefined);
579597
});
598+
getThreadSpy.mockRestore();
599+
});
600+
601+
it('uses Thread instance when messageDraftsEnabled is true', async () => {
602+
const { channel, chatClient } = await initClient();
603+
const threadMessage = messages[0];
604+
const hasThreadInstance = jest.fn();
605+
const mockThreadInstance = {
606+
threadInstanceMock: true,
607+
registerSubscriptions: jest.fn(),
608+
};
609+
const spy = jest
610+
.spyOn(chatClient, 'getThread')
611+
.mockResolvedValueOnce(mockThreadInstance);
580612

581-
await waitFor(() => expect(hasThread).toHaveBeenCalledWith(threadMessage.id));
613+
await renderComponent(
614+
{ channel, chatClient, messageDraftsEnabled: true },
615+
({ openThread, thread, threadInstance }) => {
616+
if (!thread) {
617+
openThread(threadMessage, { preventDefault: () => null });
618+
} else {
619+
hasThreadInstance(threadInstance);
620+
}
621+
},
622+
);
623+
624+
await waitFor(() => {
625+
expect(hasThreadInstance).toHaveBeenCalledWith(mockThreadInstance);
626+
expect(mockThreadInstance.registerSubscriptions).toHaveBeenCalledWith();
627+
});
628+
spy.mockRestore();
582629
});
583630

584631
it('should be able to load more messages in a thread until reaching the end', async () => {

src/components/Channel/channelState.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
import type {
1+
import {
22
Channel,
33
MessageResponse,
44
ChannelState as StreamChannelState,
5+
Thread,
56
} from 'stream-chat';
67

78
import type { ChannelState, StreamMessage } from '../../context/ChannelStateContext';
@@ -51,6 +52,7 @@ export type ChannelStateReducerAction =
5152
channel: Channel;
5253
message: StreamMessage;
5354
type: 'openThread';
55+
threadInstance?: Thread<StreamChatGenerics>;
5456
}
5557
| {
5658
error: Error;
@@ -91,6 +93,7 @@ export const makeChannelReducer =
9193
return {
9294
...state,
9395
thread: null,
96+
threadInstance: undefined,
9497
threadLoadingMore: false,
9598
threadMessages: [],
9699
};
@@ -198,11 +201,12 @@ export const makeChannelReducer =
198201
}
199202

200203
case 'openThread': {
201-
const { channel, message } = action;
204+
const { channel, message, threadInstance } = action;
202205
return {
203206
...state,
204207
thread: message,
205208
threadHasMore: true,
209+
threadInstance,
206210
threadMessages: message.id
207211
? { ...channel.state.threads }[message.id] || []
208212
: [],

src/components/Channel/hooks/useCreateChannelStateContext.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ export const useCreateChannelStateContext = (
4545
suppressAutoscroll,
4646
thread,
4747
threadHasMore,
48+
threadInstance,
4849
threadLoadingMore,
4950
threadMessages = [],
5051
videoAttachmentSizeHandler,
@@ -141,6 +142,7 @@ export const useCreateChannelStateContext = (
141142
suppressAutoscroll,
142143
thread,
143144
threadHasMore,
145+
threadInstance,
144146
threadLoadingMore,
145147
threadMessages,
146148
videoAttachmentSizeHandler,
@@ -178,6 +180,7 @@ export const useCreateChannelStateContext = (
178180
suppressAutoscroll,
179181
thread,
180182
threadHasMore,
183+
threadInstance,
181184
threadLoadingMore,
182185
threadMessagesLength,
183186
watcherCount,

src/components/Thread/Thread.tsx

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,15 @@ export type ThreadProps<V extends CustomTrigger = CustomTrigger> = {
5353
export const Thread = <V extends CustomTrigger = CustomTrigger>(
5454
props: ThreadProps<V>,
5555
) => {
56-
const { channel, channelConfig, thread } = useChannelStateContext('Thread');
57-
const threadInstance = useThreadContext();
56+
const {
57+
channel,
58+
channelConfig,
59+
thread,
60+
threadInstance: threadInstanceChannelCtx,
61+
} = useChannelStateContext<StreamChatGenerics>('Thread');
62+
const threadInstanceThreadCtx = useThreadContext();
5863

64+
const threadInstance = threadInstanceThreadCtx ?? threadInstanceChannelCtx;
5965
if ((!thread && !threadInstance) || channelConfig?.replies === false) return null;
6066

6167
// the wrapper ensures a key variable is set and the component recreates on thread switch
@@ -68,7 +74,11 @@ export const Thread = <V extends CustomTrigger = CustomTrigger>(
6874
);
6975
};
7076

71-
const selector = (nextValue: ThreadState) => ({
77+
const selector = <
78+
StreamChatGenerics extends DefaultStreamChatGenerics = DefaultStreamChatGenerics,
79+
>(
80+
nextValue: ThreadState<StreamChatGenerics>,
81+
) => ({
7282
isLoadingNext: nextValue.pagination.isLoadingNext,
7383
isLoadingPrev: nextValue.pagination.isLoadingPrev,
7484
parentMessage: nextValue.parentMessage,
@@ -91,13 +101,12 @@ const ThreadInner = <V extends CustomTrigger = CustomTrigger>(
91101
virtualized,
92102
} = props;
93103

94-
const threadInstance = useThreadContext();
95-
const { isLoadingNext, isLoadingPrev, parentMessage, replies } =
96-
useStateStore(threadInstance?.state, selector) ?? {};
104+
const threadInstanceThreadCtx = useThreadContext<StreamChatGenerics>();
97105

98106
const {
99107
thread,
100108
threadHasMore,
109+
threadInstance: threadInstanceChannelCtx,
101110
threadLoadingMore,
102111
threadMessages = [],
103112
threadSuppressAutoscroll,
@@ -112,6 +121,11 @@ const ThreadInner = <V extends CustomTrigger = CustomTrigger>(
112121
VirtualMessage,
113122
} = useComponentContext('Thread');
114123

124+
const threadInstance = threadInstanceThreadCtx ?? threadInstanceChannelCtx;
125+
126+
const { isLoadingNext, isLoadingPrev, parentMessage, replies } =
127+
useStateStore(threadInstance?.state, selector) ?? {};
128+
115129
const ThreadInput =
116130
PropInput ?? additionalMessageInputProps?.Input ?? ContextInput ?? MessageInputFlat;
117131

@@ -122,7 +136,7 @@ const ThreadInner = <V extends CustomTrigger = CustomTrigger>(
122136
const ThreadMessageList = virtualized ? VirtualizedMessageList : MessageList;
123137

124138
useEffect(() => {
125-
if (thread?.id && thread?.reply_count) {
139+
if (!threadInstance && thread?.id && thread?.reply_count) {
126140
// FIXME: integrators can customize channel query options but cannot customize channel.getReplies() options
127141
loadMoreThread();
128142
}

src/components/Thread/__tests__/Thread.test.js

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,12 @@ describe('Thread', () => {
330330
expect(channelActionContextMock.loadMoreThread).toHaveBeenCalledTimes(1);
331331
});
332332

333+
it('should not call the loadMoreThread callback on mount if the thread start has a non-zero reply count but threadInstance is provided', () => {
334+
renderComponent({ chatClient, channelStateOverrides: { threadInstance: {} } });
335+
336+
expect(channelActionContextMock.loadMoreThread).not.toHaveBeenCalled();
337+
});
338+
333339
it('should render null if replies is disabled', async () => {
334340
const client = await getTestClientWithUser();
335341
const ch = generateChannel({ getConfig: () => ({ replies: false }) });

src/components/Threads/ThreadContext.tsx

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,18 @@ import { Channel } from '../../components';
44

55
import type { PropsWithChildren } from 'react';
66
import type { Thread } from 'stream-chat';
7+
import type { DefaultStreamChatGenerics } from '../../types';
78

89
export type ThreadContextValue = Thread | undefined;
910

1011
export const ThreadContext = createContext<ThreadContextValue>(undefined);
1112

12-
export const useThreadContext = () => {
13+
export const useThreadContext = <
14+
StreamChatGenerics extends DefaultStreamChatGenerics = DefaultStreamChatGenerics,
15+
>() => {
1316
const thread = useContext(ThreadContext);
1417

15-
return thread ?? undefined;
18+
return (thread as unknown as Thread<StreamChatGenerics>) ?? undefined;
1619
};
1720

1821
export const ThreadProvider = ({

src/context/ChannelStateContext.tsx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import type {
99
MessageResponse,
1010
Mute,
1111
ChannelState as StreamChannelState,
12+
Thread,
1213
} from 'stream-chat';
1314

1415
import type {
@@ -53,6 +54,7 @@ export type ChannelState = {
5354
read?: StreamChannelState<StreamChatGenerics>['read'];
5455
thread?: StreamMessage<StreamChatGenerics> | null;
5556
threadHasMore?: boolean;
57+
threadInstance?: Thread<StreamChatGenerics>;
5658
threadLoadingMore?: boolean;
5759
threadMessages?: StreamMessage[];
5860
threadSuppressAutoscroll?: boolean;

0 commit comments

Comments
 (0)