Skip to content

Commit ea58e48

Browse files
committed
feat(chat): support rehype plugins and regenerate provider config
1 parent c0fc68f commit ea58e48

File tree

6 files changed

+43
-9
lines changed

6 files changed

+43
-9
lines changed

src/chat/button/index.tsx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ export default function Button({ type = 'default', className, children, ...rest
1717
useEffect(() => {
1818
if (!document.querySelector(`.${GLOBAL_GRADIENT_CLASSNAME}`)) {
1919
const div = document.createElement('div');
20+
div.style.setProperty('width', '0');
21+
div.style.setProperty('height', '0');
2022
div.className = GLOBAL_GRADIENT_CLASSNAME;
2123
div.innerHTML = renderToString(
2224
<svg width={0} height={0}>

src/chat/content/index.tsx

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ const Content = forwardRef<IContentRef, IContentProps>(function (
2626
{ data, placeholder, robotIcon = true, scrollable = true, onRegenerate, onStop },
2727
forwardedRef
2828
) {
29-
const { maxRegenerateCount, copy } = useContext();
29+
const { maxRegenerateCount, copy, regenerate } = useContext();
3030
const containerRef = useRef<HTMLDivElement>(null);
3131

3232
const [isStickyAtBottom, setIsStickyAtBottom] = useState<boolean>(true);
@@ -105,14 +105,18 @@ const Content = forwardRef<IContentRef, IContentProps>(function (
105105
{dataValid ? (
106106
<div className="dtc__aigc__content__inner__holder">
107107
{data.map((row, idx) => {
108+
const defaultRegenerate =
109+
idx === data.length - 1 && row.messages.length < maxRegenerateCount;
108110
return (
109111
<React.Fragment key={row.id}>
110112
<Prompt data={row} />
111113
<Message
114+
prompt={row}
112115
data={row.messages}
113116
regenerate={
114-
idx === data.length - 1 &&
115-
row.messages.length < maxRegenerateCount
117+
typeof regenerate === 'function'
118+
? regenerate(row, idx, data)
119+
: regenerate ?? defaultRegenerate
116120
}
117121
copy={copy}
118122
onRegenerate={(message) => onRegenerate?.(message, row)}

src/chat/index.tsx

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,26 @@ function Chat({
2525
chat,
2626
components,
2727
maxRegenerateCount = DEFAULT_MAX_REGENERATE_COUNT,
28+
regenerate,
2829
copy,
2930
messageIcons,
31+
rehypePlugins,
32+
remarkPlugins,
3033
children,
3134
}: PropsWithChildren<ChatProviderConfig>) {
3235
return (
33-
<context.Provider value={{ chat, components, maxRegenerateCount, copy, messageIcons }}>
36+
<context.Provider
37+
value={{
38+
chat,
39+
components,
40+
maxRegenerateCount,
41+
copy,
42+
messageIcons,
43+
regenerate,
44+
rehypePlugins,
45+
remarkPlugins,
46+
}}
47+
>
3448
{children}
3549
</context.Provider>
3650
);

src/chat/message/index.tsx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ export default function Message({
3939
onLazyRendered,
4040
}: IMessageProps) {
4141
const divRef = useRef<HTMLDivElement>(null);
42-
const { components = {}, messageIcons } = useContext();
42+
const { components = {}, messageIcons, rehypePlugins, remarkPlugins } = useContext();
4343

4444
// 当前 Message 的懒加载,是否已经加载过
4545
const [lazyRendered, setLazyRendered] = useState(false);
@@ -144,6 +144,8 @@ export default function Message({
144144
<Markdown
145145
typing={typing}
146146
components={composedComponents}
147+
rehypePlugins={rehypePlugins}
148+
remarkPlugins={remarkPlugins}
147149
onMount={() => {
148150
mountCallback.current();
149151
}}

src/chat/useChat.ts

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -169,17 +169,20 @@ export default function useChat<
169169
function _updateMessage(
170170
promptId: Id,
171171
messageId: Id,
172-
predicate: (message: Message) => Message
172+
predicate: (message: Message) => Message,
173+
triggerRerender?: boolean
173174
): void;
174175
function _updateMessage(
175176
promptId: Id,
176177
messageId: Id,
177-
data: Partial<Omit<MessageProperties, 'id'>>
178+
data: Partial<Omit<MessageProperties, 'id'>>,
179+
triggerRerender?: boolean
178180
): void;
179181
function _updateMessage(
180182
promptId: Id,
181183
messageId: Id,
182-
dataOrPredicate: Partial<Omit<MessageProperties, 'id'>> | ((message: Message) => Message)
184+
dataOrPredicate: Partial<Omit<MessageProperties, 'id'>> | ((message: Message) => Message),
185+
triggerRerender?: boolean
183186
) {
184187
if (!state.current) return;
185188
state.current = produce(state.current, (draft) => {
@@ -192,7 +195,9 @@ export default function useChat<
192195
Object.assign(message, dataOrPredicate);
193196
}
194197
});
195-
update();
198+
if (triggerRerender !== false) {
199+
update();
200+
}
196201
}
197202

198203
function _getMessage(promptId: Id, messageId: Id) {

src/chat/useContext.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import React from 'react';
22
import { type Components } from 'react-markdown';
3+
import { type ReactMarkdownOptions } from 'react-markdown/lib/react-markdown';
34

45
import { type ICopyProps } from '../copy';
56
import type { Message, Prompt } from './entity';
@@ -22,8 +23,14 @@ export interface IChatContext {
2223
* 重新回答的最大次数
2324
*/
2425
maxRegenerateCount: number;
26+
/**
27+
* 是否支持重新生成
28+
*/
29+
regenerate?: boolean | ((prompt: Prompt, index: number, array: Prompt[]) => boolean);
2530
copy?: boolean | CopyOptions;
2631
messageIcons?: React.ReactNode | ((record: Message, prompt: Prompt) => React.ReactNode);
32+
rehypePlugins?: ReactMarkdownOptions['rehypePlugins'];
33+
remarkPlugins?: ReactMarkdownOptions['remarkPlugins'];
2734
}
2835

2936
export const context = React.createContext<IChatContext>({

0 commit comments

Comments
 (0)