Skip to content

Commit c8dc8d7

Browse files
committed
allow multiple generations at the same time
1 parent 518e077 commit c8dc8d7

File tree

6 files changed

+77
-50
lines changed

6 files changed

+77
-50
lines changed

examples/server/webui/src/components/ChatMessage.tsx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@ export default function ChatMessage({
2020
msg,
2121
id,
2222
scrollToBottom,
23+
isPending,
2324
}: {
2425
msg: Message | PendingMessage;
2526
id?: string;
2627
scrollToBottom: (requiresNearBottom: boolean) => void;
28+
isPending?: boolean;
2729
}) {
2830
const { viewingConversation, replaceMessageAndGenerate, config } =
2931
useAppContext();
@@ -42,8 +44,6 @@ export default function ChatMessage({
4244
[msg.timings]
4345
);
4446

45-
const isPending: boolean = !!(msg as PendingMessage).convId;
46-
4747
// for reasoning model, we split the message into content and thought
4848
// TODO: implement this as remark/rehype plugin in the future
4949
const { content, thought, isThinking }: SplitMessage = useMemo(() => {

examples/server/webui/src/components/ChatScreen.tsx

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,23 @@ import { useAppContext } from '../utils/app.context';
33
import StorageUtils from '../utils/storage';
44
import { useNavigate } from 'react-router';
55
import ChatMessage from './ChatMessage';
6+
import { PendingMessage } from '../utils/types';
67

78
export default function ChatScreen() {
89
const {
910
viewingConversation,
1011
sendMessage,
1112
isGenerating,
1213
stopGenerating,
13-
pendingMessage,
14+
pendingMessages,
1415
} = useAppContext();
1516
const [inputMsg, setInputMsg] = useState('');
1617
const containerRef = useRef<HTMLDivElement>(null);
1718
const navigate = useNavigate();
1819

20+
const currConvId = viewingConversation?.id ?? '';
21+
const pendingMsg: PendingMessage | undefined = pendingMessages[currConvId];
22+
1923
const scrollToBottom = (requiresNearBottom: boolean) => {
2024
if (!containerRef.current) return;
2125
const msgListElem = containerRef.current;
@@ -70,14 +74,14 @@ export default function ChatScreen() {
7074
<ChatMessage key={msg.id} msg={msg} scrollToBottom={scrollToBottom} />
7175
))}
7276

73-
{pendingMessage !== null &&
74-
pendingMessage.convId === viewingConversation?.id && (
75-
<ChatMessage
76-
msg={pendingMessage}
77-
scrollToBottom={scrollToBottom}
78-
id="pending-msg"
79-
/>
80-
)}
77+
{pendingMsg && (
78+
<ChatMessage
79+
msg={pendingMsg}
80+
scrollToBottom={scrollToBottom}
81+
isPending
82+
id="pending-msg"
83+
/>
84+
)}
8185
</div>
8286

8387
{/* chat input */}
@@ -97,8 +101,11 @@ export default function ChatScreen() {
97101
id="msg-input"
98102
dir="auto"
99103
></textarea>
100-
{isGenerating ? (
101-
<button className="btn btn-neutral ml-2" onClick={stopGenerating}>
104+
{isGenerating(currConvId) ? (
105+
<button
106+
className="btn btn-neutral ml-2"
107+
onClick={() => stopGenerating(currConvId)}
108+
>
102109
Stop
103110
</button>
104111
) : (

examples/server/webui/src/components/Header.tsx

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,10 @@ export default function Header() {
2727
}, [selectedTheme]);
2828

2929
const { isGenerating, viewingConversation } = useAppContext();
30+
const isCurrConvGenerating = isGenerating(viewingConversation?.id ?? '');
3031

3132
const removeConversation = () => {
32-
if (isGenerating || !viewingConversation) return;
33+
if (isCurrConvGenerating || !viewingConversation) return;
3334
const convId = viewingConversation.id;
3435
if (window.confirm('Are you sure to delete this conversation?')) {
3536
StorageUtils.remove(convId);
@@ -38,7 +39,7 @@ export default function Header() {
3839
};
3940

4041
const downloadConversation = () => {
41-
if (isGenerating || !viewingConversation) return;
42+
if (isCurrConvGenerating || !viewingConversation) return;
4243
const convId = viewingConversation.id;
4344
const conversationJson = JSON.stringify(viewingConversation, null, 2);
4445
const blob = new Blob([conversationJson], { type: 'application/json' });
@@ -81,7 +82,7 @@ export default function Header() {
8182
tabIndex={0}
8283
role="button"
8384
className="btn m-1"
84-
disabled={isGenerating}
85+
disabled={isCurrConvGenerating}
8586
>
8687
<svg
8788
xmlns="http://www.w3.org/2000/svg"
@@ -108,11 +109,7 @@ export default function Header() {
108109
</ul>
109110
</div>
110111
<div className="tooltip tooltip-bottom" data-tip="Settings">
111-
<button
112-
className="btn"
113-
disabled={isGenerating}
114-
onClick={() => setShowSettingDialog(true)}
115-
>
112+
<button className="btn" onClick={() => setShowSettingDialog(true)}>
116113
{/* settings button */}
117114
<svg
118115
xmlns="http://www.w3.org/2000/svg"

examples/server/webui/src/components/Sidebar.tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import { classNames } from '../utils/misc';
33
import { Conversation } from '../utils/types';
44
import StorageUtils from '../utils/storage';
55
import { useNavigate, useParams } from 'react-router';
6+
import { useAppContext } from '../utils/app.context';
67

78
export default function Sidebar() {
89
const params = useParams();

examples/server/webui/src/utils/app.context.tsx

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@ import { BASE_URL, CONFIG_DEFAULT, isDev } from '../Config';
1010
import { matchPath, useLocation } from 'react-router';
1111

1212
interface AppContextValue {
13-
isGenerating: boolean;
1413
viewingConversation: Conversation | null;
15-
pendingMessage: PendingMessage | null;
14+
pendingMessages: Record<Conversation['id'], PendingMessage>;
15+
isGenerating: (convId: string) => boolean;
1616
sendMessage: (
1717
convId: string,
1818
content: string,
1919
onChunk?: CallbackGeneratedChunk
2020
) => Promise<boolean>;
21-
stopGenerating: () => void;
21+
stopGenerating: (convId: string) => void;
2222
replaceMessageAndGenerate: (
2323
convId: string,
2424
origMsgId: Message['id'],
@@ -45,13 +45,14 @@ export const AppContextProvider = ({
4545
const params = matchPath('/chat/:convId', pathname);
4646
const convId = params?.params?.convId;
4747

48-
const [isGenerating, setIsGenerating] = useState(false);
4948
const [viewingConversation, setViewingConversation] =
5049
useState<Conversation | null>(null);
51-
const [pendingMessage, setPendingMessage] = useState<PendingMessage | null>(
52-
null
53-
);
54-
const [abortController, setAbortController] = useState(new AbortController());
50+
const [pendingMessages, setPendingMessages] = useState<
51+
Record<Conversation['id'], PendingMessage>
52+
>({});
53+
const [aborts, setAborts] = useState<
54+
Record<Conversation['id'], AbortController>
55+
>({});
5556
const [config, setConfig] = useState(StorageUtils.getConfig());
5657

5758
useEffect(() => {
@@ -66,11 +67,41 @@ export const AppContextProvider = ({
6667
};
6768
}, [convId]);
6869

70+
const setPending = (convId: string, pendingMsg: PendingMessage | null) => {
71+
// if pendingMsg is null, remove the key from the object
72+
if (!pendingMsg) {
73+
setPendingMessages((prev) => {
74+
const newState = { ...prev };
75+
delete newState[convId];
76+
return newState;
77+
});
78+
} else {
79+
setPendingMessages((prev) => ({ ...prev, [convId]: pendingMsg }));
80+
}
81+
};
82+
83+
const setAbort = (convId: string, controller: AbortController | null) => {
84+
if (!controller) {
85+
setAborts((prev) => {
86+
const newState = { ...prev };
87+
delete newState[convId];
88+
return newState;
89+
});
90+
} else {
91+
setAborts((prev) => ({ ...prev, [convId]: controller }));
92+
}
93+
};
94+
95+
////////////////////////////////////////////////////////////////////////
96+
// public functions
97+
98+
const isGenerating = (convId: string) => !!pendingMessages[convId];
99+
69100
const generateMessage = async (
70101
convId: string,
71102
onChunk?: CallbackGeneratedChunk
72103
) => {
73-
if (isGenerating) return;
104+
if (isGenerating(convId)) return;
74105

75106
const config = StorageUtils.getConfig();
76107
const currConversation = StorageUtils.getOneConversation(convId);
@@ -79,16 +110,14 @@ export const AppContextProvider = ({
79110
}
80111

81112
const abortController = new AbortController();
82-
setIsGenerating(true);
83-
setAbortController(abortController);
113+
setAbort(convId, abortController);
84114

85115
let pendingMsg: PendingMessage = {
86-
convId,
87116
id: Date.now() + 1,
88117
role: 'assistant',
89118
content: null,
90119
};
91-
setPendingMessage(pendingMsg);
120+
setPending(convId, pendingMsg);
92121

93122
try {
94123
// prepare messages for API
@@ -157,7 +186,6 @@ export const AppContextProvider = ({
157186
const lastContent = pendingMsg.content || '';
158187
if (addedContent) {
159188
pendingMsg = {
160-
convId,
161189
id: pendingMsg.id,
162190
role: 'assistant',
163191
content: lastContent + addedContent,
@@ -173,18 +201,15 @@ export const AppContextProvider = ({
173201
predicted_ms: timings.predicted_ms,
174202
};
175203
}
176-
setPendingMessage(pendingMsg);
204+
setPending(convId, pendingMsg);
177205
onChunk?.();
178206
}
179207
} catch (err) {
180-
console.error(err);
181-
setPendingMessage(null);
182-
setIsGenerating(false);
208+
setPending(convId, null);
183209
if ((err as Error).name === 'AbortError') {
184210
// user stopped the generation via stopGeneration() function
185211
// we can safely ignore this error
186212
} else {
187-
setIsGenerating(false);
188213
console.error(err);
189214
// eslint-disable-next-line @typescript-eslint/no-explicit-any
190215
alert((err as any)?.message ?? 'Unknown error');
@@ -200,8 +225,7 @@ export const AppContextProvider = ({
200225
timings: pendingMsg.timings,
201226
});
202227
}
203-
setPendingMessage(null);
204-
setIsGenerating(false);
228+
setPending(convId, null);
205229
onChunk?.(); // trigger scroll to bottom
206230
};
207231

@@ -210,7 +234,7 @@ export const AppContextProvider = ({
210234
content: string,
211235
onChunk?: CallbackGeneratedChunk
212236
): Promise<boolean> => {
213-
if (isGenerating || content.trim().length === 0) return false;
237+
if (isGenerating(convId) || content.trim().length === 0) return false;
214238

215239
StorageUtils.appendMsg(convId, {
216240
id: Date.now(),
@@ -228,10 +252,9 @@ export const AppContextProvider = ({
228252
return false;
229253
};
230254

231-
const stopGenerating = () => {
232-
setIsGenerating(false);
233-
setPendingMessage(null);
234-
abortController.abort();
255+
const stopGenerating = (convId: string) => {
256+
setPending(convId, null);
257+
aborts[convId]?.abort();
235258
};
236259

237260
// if content is undefined, we remove last assistant message
@@ -241,7 +264,7 @@ export const AppContextProvider = ({
241264
content?: string,
242265
onChunk?: CallbackGeneratedChunk
243266
) => {
244-
if (isGenerating) return;
267+
if (isGenerating(convId)) return;
245268

246269
StorageUtils.filterAndKeepMsgs(convId, (msg) => msg.id < origMsgId);
247270
if (content) {
@@ -265,7 +288,7 @@ export const AppContextProvider = ({
265288
value={{
266289
isGenerating,
267290
viewingConversation,
268-
pendingMessage,
291+
pendingMessages,
269292
sendMessage,
270293
stopGenerating,
271294
replaceMessageAndGenerate,

examples/server/webui/src/utils/types.ts

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,5 @@ export interface Conversation {
2121
}
2222

2323
export type PendingMessage = Omit<Message, 'content'> & {
24-
convId: string;
2524
content: string | null;
2625
};

0 commit comments

Comments
 (0)