Skip to content

Commit 7af53a0

Browse files
committed
feat: use Thread class instance to manage Thread component state when drafts enabled
1 parent c19ea09 commit 7af53a0

File tree

8 files changed

+141
-28
lines changed

8 files changed

+141
-28
lines changed

src/components/Channel/Channel.tsx

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ import type {
8282
MessageResponse,
8383
SendMessageAPIResponse,
8484
StreamChat,
85+
Thread,
8586
UpdatedMessage,
8687
UserResponse,
8788
} from 'stream-chat';
@@ -1177,14 +1178,42 @@ const ChannelInner = <
11771178
event?: React.BaseSyntheticEvent,
11781179
) => {
11791180
event?.preventDefault();
1180-
setQuotedMessage((current) => {
1181-
if (current?.parent_id !== message?.parent_id) {
1182-
return undefined;
1183-
} else {
1184-
return current;
1181+
if (messageDraftsEnabled) {
1182+
let threadInstance = client.threads.threadsById[message.id];
1183+
if (threadInstance) {
1184+
dispatch({ channel, message, threadInstance, type: 'openThread' });
1185+
return;
11851186
}
1186-
});
1187-
dispatch({ channel, message, type: 'openThread' });
1187+
1188+
dispatch({
1189+
channel,
1190+
message,
1191+
threadInstance: {} as Thread<StreamChatGenerics>,
1192+
type: 'openThread',
1193+
});
1194+
1195+
client
1196+
.getThread(message.id, { reply_limit: DEFAULT_THREAD_PAGE_SIZE })
1197+
.then((t: Thread<StreamChatGenerics>) => {
1198+
t.registerSubscriptions();
1199+
dispatch({
1200+
channel,
1201+
message,
1202+
threadInstance: t,
1203+
type: 'openThread',
1204+
});
1205+
client.threads.addThread(t);
1206+
});
1207+
} else {
1208+
setQuotedMessage((current) => {
1209+
if (current?.parent_id !== message?.parent_id) {
1210+
return undefined;
1211+
} else {
1212+
return current;
1213+
}
1214+
});
1215+
dispatch({ channel, message, type: 'openThread' });
1216+
}
11881217
};
11891218

11901219
const closeThread = (event?: React.BaseSyntheticEvent) => {
@@ -1215,7 +1244,13 @@ const ChannelInner = <
12151244

12161245
const loadMoreThread = async (limit: number = DEFAULT_THREAD_PAGE_SIZE) => {
12171246
// FIXME: should prevent loading more, if state.thread.reply_count === channel.state.threads[parentID].length
1218-
if (state.threadLoadingMore || !state.thread || !state.threadHasMore) return;
1247+
if (
1248+
state.threadInstance ||
1249+
state.threadLoadingMore ||
1250+
!state.thread ||
1251+
!state.threadHasMore
1252+
)
1253+
return;
12191254

12201255
dispatch({ type: 'startLoadingThread' });
12211256
const parentId = state.thread.id;

src/components/Channel/__tests__/Channel.test.js

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -567,18 +567,65 @@ describe('Channel', () => {
567567
const { channel, chatClient } = await initClient();
568568
const threadMessage = messages[0];
569569
const hasThread = jest.fn();
570+
const hasThreadInstance = jest.fn();
571+
const mockThreadInstance = {
572+
threadInstanceMock: true,
573+
registerSubscriptions: jest.fn(),
574+
};
575+
const getThreadSpy = jest
576+
.spyOn(chatClient, 'getThread')
577+
.mockResolvedValueOnce(mockThreadInstance);
570578

571579
// this renders Channel, calls openThread from a child context consumer with a message,
572580
// and then calls hasThread with the thread id if it was set.
573-
await renderComponent({ channel, chatClient }, ({ openThread, thread }) => {
574-
if (!thread) {
575-
openThread(threadMessage, { preventDefault: () => null });
576-
} else {
577-
hasThread(thread.id);
578-
}
581+
await renderComponent(
582+
{ channel, chatClient },
583+
({ openThread, thread, threadInstance }) => {
584+
if (!thread) {
585+
openThread(threadMessage, { preventDefault: () => null });
586+
} else {
587+
hasThread(thread.id);
588+
hasThreadInstance(threadInstance);
589+
}
590+
},
591+
);
592+
593+
await waitFor(() => {
594+
expect(hasThread).toHaveBeenCalledWith(threadMessage.id);
595+
expect(getThreadSpy).not.toHaveBeenCalled();
596+
expect(hasThreadInstance).toHaveBeenCalledWith(undefined);
579597
});
598+
getThreadSpy.mockRestore();
599+
});
600+
601+
it('uses Thread instance when messageDraftsEnabled is true', async () => {
602+
const { channel, chatClient } = await initClient();
603+
const threadMessage = messages[0];
604+
const hasThreadInstance = jest.fn();
605+
const mockThreadInstance = {
606+
threadInstanceMock: true,
607+
registerSubscriptions: jest.fn(),
608+
};
609+
const spy = jest
610+
.spyOn(chatClient, 'getThread')
611+
.mockResolvedValueOnce(mockThreadInstance);
580612

581-
await waitFor(() => expect(hasThread).toHaveBeenCalledWith(threadMessage.id));
613+
await renderComponent(
614+
{ channel, chatClient, messageDraftsEnabled: true },
615+
({ openThread, thread, threadInstance }) => {
616+
if (!thread) {
617+
openThread(threadMessage, { preventDefault: () => null });
618+
} else {
619+
hasThreadInstance(threadInstance);
620+
}
621+
},
622+
);
623+
624+
await waitFor(() => {
625+
expect(hasThreadInstance).toHaveBeenCalledWith(mockThreadInstance);
626+
expect(mockThreadInstance.registerSubscriptions).toHaveBeenCalledWith();
627+
});
628+
spy.mockRestore();
582629
});
583630

584631
it('should be able to load more messages in a thread until reaching the end', async () => {

src/components/Channel/channelState.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
import type {
1+
import {
22
Channel,
33
MessageResponse,
44
ChannelState as StreamChannelState,
5+
Thread,
56
} from 'stream-chat';
67

78
import type { ChannelState, StreamMessage } from '../../context/ChannelStateContext';
@@ -57,6 +58,7 @@ export type ChannelStateReducerAction<
5758
channel: Channel<StreamChatGenerics>;
5859
message: StreamMessage<StreamChatGenerics>;
5960
type: 'openThread';
61+
threadInstance?: Thread<StreamChatGenerics>;
6062
}
6163
| {
6264
error: Error;
@@ -101,6 +103,7 @@ export const makeChannelReducer =
101103
return {
102104
...state,
103105
thread: null,
106+
threadInstance: undefined,
104107
threadLoadingMore: false,
105108
threadMessages: [],
106109
};
@@ -208,11 +211,12 @@ export const makeChannelReducer =
208211
}
209212

210213
case 'openThread': {
211-
const { channel, message } = action;
214+
const { channel, message, threadInstance } = action;
212215
return {
213216
...state,
214217
thread: message,
215218
threadHasMore: true,
219+
threadInstance,
216220
threadMessages: message.id
217221
? { ...channel.state.threads }[message.id] || []
218222
: [],

src/components/Channel/hooks/useCreateChannelStateContext.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ export const useCreateChannelStateContext = <
4949
suppressAutoscroll,
5050
thread,
5151
threadHasMore,
52+
threadInstance,
5253
threadLoadingMore,
5354
threadMessages = [],
5455
videoAttachmentSizeHandler,
@@ -145,6 +146,7 @@ export const useCreateChannelStateContext = <
145146
suppressAutoscroll,
146147
thread,
147148
threadHasMore,
149+
threadInstance,
148150
threadLoadingMore,
149151
threadMessages,
150152
videoAttachmentSizeHandler,
@@ -182,6 +184,7 @@ export const useCreateChannelStateContext = <
182184
suppressAutoscroll,
183185
thread,
184186
threadHasMore,
187+
threadInstance,
185188
threadLoadingMore,
186189
threadMessagesLength,
187190
watcherCount,

src/components/Thread/Thread.tsx

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,15 @@ export const Thread = <
6262
>(
6363
props: ThreadProps<StreamChatGenerics, V>,
6464
) => {
65-
const { channel, channelConfig, thread } =
66-
useChannelStateContext<StreamChatGenerics>('Thread');
67-
const threadInstance = useThreadContext();
65+
const {
66+
channel,
67+
channelConfig,
68+
thread,
69+
threadInstance: threadInstanceChannelCtx,
70+
} = useChannelStateContext<StreamChatGenerics>('Thread');
71+
const threadInstanceThreadCtx = useThreadContext();
6872

73+
const threadInstance = threadInstanceThreadCtx ?? threadInstanceChannelCtx;
6974
if ((!thread && !threadInstance) || channelConfig?.replies === false) return null;
7075

7176
// the wrapper ensures a key variable is set and the component recreates on thread switch
@@ -78,7 +83,11 @@ export const Thread = <
7883
);
7984
};
8085

81-
const selector = (nextValue: ThreadState) => ({
86+
const selector = <
87+
StreamChatGenerics extends DefaultStreamChatGenerics = DefaultStreamChatGenerics,
88+
>(
89+
nextValue: ThreadState<StreamChatGenerics>,
90+
) => ({
8291
isLoadingNext: nextValue.pagination.isLoadingNext,
8392
isLoadingPrev: nextValue.pagination.isLoadingPrev,
8493
parentMessage: nextValue.parentMessage,
@@ -104,13 +113,12 @@ const ThreadInner = <
104113
virtualized,
105114
} = props;
106115

107-
const threadInstance = useThreadContext();
108-
const { isLoadingNext, isLoadingPrev, parentMessage, replies } =
109-
useStateStore(threadInstance?.state, selector) ?? {};
116+
const threadInstanceThreadCtx = useThreadContext<StreamChatGenerics>();
110117

111118
const {
112119
thread,
113120
threadHasMore,
121+
threadInstance: threadInstanceChannelCtx,
114122
threadLoadingMore,
115123
threadMessages = [],
116124
threadSuppressAutoscroll,
@@ -126,6 +134,11 @@ const ThreadInner = <
126134
VirtualMessage,
127135
} = useComponentContext<StreamChatGenerics>('Thread');
128136

137+
const threadInstance = threadInstanceThreadCtx ?? threadInstanceChannelCtx;
138+
139+
const { isLoadingNext, isLoadingPrev, parentMessage, replies } =
140+
useStateStore(threadInstance?.state, selector) ?? {};
141+
129142
const ThreadInput =
130143
PropInput ?? additionalMessageInputProps?.Input ?? ContextInput ?? MessageInputFlat;
131144

@@ -136,7 +149,7 @@ const ThreadInner = <
136149
const ThreadMessageList = virtualized ? VirtualizedMessageList : MessageList;
137150

138151
useEffect(() => {
139-
if (thread?.id && thread?.reply_count) {
152+
if (!threadInstance && thread?.id && thread?.reply_count) {
140153
// FIXME: integrators can customize channel query options but cannot customize channel.getReplies() options
141154
loadMoreThread();
142155
}

src/components/Thread/__tests__/Thread.test.js

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,12 @@ describe('Thread', () => {
330330
expect(channelActionContextMock.loadMoreThread).toHaveBeenCalledTimes(1);
331331
});
332332

333+
it('should not call the loadMoreThread callback on mount if the thread start has a non-zero reply count but threadInstance is provided', () => {
334+
renderComponent({ chatClient, channelStateOverrides: { threadInstance: {} } });
335+
336+
expect(channelActionContextMock.loadMoreThread).not.toHaveBeenCalled();
337+
});
338+
333339
it('should render null if replies is disabled', async () => {
334340
const client = await getTestClientWithUser();
335341
const ch = generateChannel({ getConfig: () => ({ replies: false }) });

src/components/Threads/ThreadContext.tsx

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,19 @@ import React, { createContext, useContext } from 'react';
33
import { Channel } from '../../components';
44

55
import type { PropsWithChildren } from 'react';
6-
import { Thread } from 'stream-chat';
6+
import type { Thread } from 'stream-chat';
7+
import type { DefaultStreamChatGenerics } from '../../types';
78

89
export type ThreadContextValue = Thread | undefined;
910

1011
export const ThreadContext = createContext<ThreadContextValue>(undefined);
1112

12-
export const useThreadContext = () => {
13+
export const useThreadContext = <
14+
StreamChatGenerics extends DefaultStreamChatGenerics = DefaultStreamChatGenerics,
15+
>() => {
1316
const thread = useContext(ThreadContext);
1417

15-
return thread ?? undefined;
18+
return (thread as unknown as Thread<StreamChatGenerics>) ?? undefined;
1619
};
1720

1821
export const ThreadProvider = ({

src/context/ChannelStateContext.tsx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import type {
77
MessageResponse,
88
Mute,
99
ChannelState as StreamChannelState,
10+
Thread,
1011
} from 'stream-chat';
1112

1213
import type {
@@ -50,6 +51,7 @@ export type ChannelState<
5051
read?: StreamChannelState<StreamChatGenerics>['read'];
5152
thread?: StreamMessage<StreamChatGenerics> | null;
5253
threadHasMore?: boolean;
54+
threadInstance?: Thread<StreamChatGenerics>;
5355
threadLoadingMore?: boolean;
5456
threadMessages?: StreamMessage<StreamChatGenerics>[];
5557
threadSuppressAutoscroll?: boolean;

0 commit comments

Comments
 (0)