Skip to content

Commit 7a7f68c

Browse files
committed
feat: Enable dynamic AI model selection and tool filtering in chat interactions.
1 parent 11f70e8 commit 7a7f68c

File tree

8 files changed

+88
-8
lines changed

8 files changed

+88
-8
lines changed

apps/nextjs-example/app/api/chat/route.ts

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,28 @@
11
import { createOpenAI } from "@ai-sdk/openai";
22
import { streamText, convertToModelMessages, type UIMessage, smoothStream } from "ai";
33
import { weatherTool } from "../../../lib/tools/weather-tool";
4+
import { getModel, getDisabledToolsFiter } from "@creatorem/ai-chat/server";
45

56
const groq = createOpenAI({
67
apiKey: process.env.GROQ_API_KEY ?? '',
78
baseURL: 'https://api.groq.com/openai/v1',
89
});
910

1011
export async function POST(req: Request) {
11-
const { messages }: { messages: UIMessage[] } = await req.json();
12+
const { messages, ...body }: { messages: UIMessage[] } = await req.json();
13+
14+
const model = groq(getModel(body) ?? 'llama-3.3-70b-versatile')
15+
const toolsFilter = getDisabledToolsFiter(body);
1216

1317
const result = streamText({
1418
// model: groq('llama-3.3-70b-versatile'),
15-
model: groq('meta-llama/llama-4-scout-17b-16e-instruct'),
19+
// model: groq('meta-llama/llama-4-scout-17b-16e-instruct'),
20+
// model: groq('meta-llama/llama-4-scout-17b-16e-instruct'),
21+
model,
1622
messages: await convertToModelMessages(messages),
17-
tools: {
23+
tools: toolsFilter({
1824
weather: weatherTool,
19-
},
25+
}),
2026
experimental_transform: smoothStream({
2127
delayInMs: 30, // optional: defaults to 10ms
2228
chunking: 'line', // optional: defaults to 'word'

packages-test-3/ai-chat/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"./hook-types": "./src/hook-types.ts",
2424
"./utils": "./src/utils/index.ts",
2525
"./hooks": "./src/hooks/index.ts",
26+
"./server": "./src/server/index.ts",
2627
"./types/*": "./src/types/*.ts"
2728
},
2829
"main": "./dist/index.js",

packages-test-3/ai-chat/src/ai-provider.tsx

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ export type AiContextType = {
3838
toolkit?: Toolkit | undefined;
3939
callSettings?: LanguageModelV1CallSettings | undefined;
4040
config?: LanguageModelConfig | undefined;
41+
selectedModel: string | null;
42+
disabledTools: string[];
4143
chatOptions?: Omit<UseChatOptions<Thread['messages'][0]> & ChatInit<Thread['messages'][0]>, 'id' | 'transport'> & {
4244
transportOptions?: HttpChatTransportInitOptions<Thread['messages'][0]>
4345
}
@@ -68,11 +70,16 @@ export const useAiEvent = <TEvent extends keyof AiChatEvents>(name: TEvent, p: (
6870
}, [eventHandler, name, p])
6971
};
7072

71-
export function AiProvider({ children, ...value }: { children: React.ReactNode } & Omit<AiContextType, 'eventHandler'>) {
73+
export function AiProvider({ children, ...value }: { children: React.ReactNode } & Omit<AiContextType, 'eventHandler' | 'selectedModel' | 'disabledTools'> & Partial<Pick<AiContextType, 'selectedModel' | 'disabledTools'>>) {
7274
// Create store once
7375
const storeRef = useRef<StoreApi<AiContextType> | null>(null);
7476
if (storeRef.current === null) {
75-
storeRef.current = createStore<AiContextType>(() => ({ ...value, eventHandler: new AiChatEventHandler() }));
77+
storeRef.current = createStore<AiContextType>(() => ({
78+
...value,
79+
selectedModel: value.selectedModel ?? null,
80+
disabledTools: value.disabledTools ?? [],
81+
eventHandler: new AiChatEventHandler(),
82+
}));
7683
}
7784

7885
return <AiContextStoreCtx.Provider value={storeRef.current}>
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"use client";
2+
3+
import { useCallback } from "react";
4+
import { useAiContextStore } from "../../ai-provider";
5+
import { ActionButtonElement, ActionButtonProps, createActionButton } from "../../utils/create-action-button";
6+
7+
type SelectModelProps = {
8+
model: string;
9+
};
10+
11+
const useComposerSelectModel = ({ model }: SelectModelProps) => {
12+
const aiContextStore = useAiContextStore();
13+
14+
return useCallback(() => {
15+
aiContextStore.setState({ selectedModel: model });
16+
}, [aiContextStore, model]);
17+
};
18+
19+
export namespace ComposerPrimitiveSelectModel {
20+
export type Element = ActionButtonElement;
21+
export type Props = ActionButtonProps<typeof useComposerSelectModel>;
22+
}
23+
24+
export const ComposerPrimitiveSelectModel = createActionButton(
25+
"ComposerPrimitive.SelectModel",
26+
useComposerSelectModel,
27+
["model"],
28+
);

packages-test-3/ai-chat/src/primitives/composer/index.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ export { ComposerPrimitiveRoot as Root, useComposer } from "./composer-provider"
22
export { ComposerPrimitiveInput as Input } from "./composer-input";
33
export { ComposerPrimitiveSend as Send,useComposerSend } from "./composer-send";
44
export { ComposerPrimitiveCancel as Cancel } from "./composer-cancel";
5+
export { ComposerPrimitiveSelectModel as SelectModel } from "./composer-select-model";
56
export { ComposerPrimitiveAttachments as Attachments, ComposerPrimitiveAttachmentByIndex as AttachmentByIndex } from "./composer-attachments";
6-
export * from './composer-provider'
7+
export * from './composer-provider'

packages-test-3/ai-chat/src/primitives/thread/thread-root.tsx

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import { ToolCallMessagePartComponent } from "../../types/message-part-component
2121
import { Unsubscribe } from "../../types/unsuscribe";
2222
import { toToolsJSONSchema } from "../../stream/schema-utils";
2323
import { Tool } from "../../stream/tool-types";
24+
import { BODY_KEY } from "../../utils/request-keys";
2425

2526
export type CustomUIDataTypes = {
2627
textDelta: string;
@@ -190,6 +191,8 @@ export function ThreadPrimitiveRoot({ children, ...value }: { children: React.Re
190191
const system = useAiContext(s => s.system);
191192
const callSettings = useAiContext(s => s.callSettings);
192193
const config = useAiContext(s => s.config);
194+
const selectedModel = useAiContext(s => s.selectedModel);
195+
const disabledTools = useAiContext(s => s.disabledTools);
193196
const activeThreadId = useThreads(s => s.activeThreadId);
194197
const [title, setTitle] = useState('New thread');
195198
const [status, setStatus] = useState<Thread['status']>('regular');
@@ -228,6 +231,8 @@ export function ThreadPrimitiveRoot({ children, ...value }: { children: React.Re
228231
system?: string;
229232
callSettings?: unknown;
230233
config?: unknown;
234+
selectedModel?: string;
235+
disabledTools?: string[];
231236
}>({ tools: {} });
232237
const addToolOutputRef = useRef<
233238
null | ((
@@ -259,6 +264,8 @@ export function ThreadPrimitiveRoot({ children, ...value }: { children: React.Re
259264
...(context.system ? { system: context.system } : {}),
260265
...(context.callSettings ? { callSettings: context.callSettings } : {}),
261266
...(context.config ? { config: context.config } : {}),
267+
[BODY_KEY.SELECTED_MODEL]: context.selectedModel ?? "",
268+
[BODY_KEY.DISABLED_TOOLS]: context.disabledTools ?? [],
262269
tools: toToolsJSONSchema(Object.values(toolsRef.current) as any) as any,
263270
},
264271
};
@@ -649,6 +656,8 @@ export function ThreadPrimitiveRoot({ children, ...value }: { children: React.Re
649656
system,
650657
callSettings,
651658
config,
659+
selectedModel: selectedModel ?? undefined,
660+
disabledTools: disabledTools ?? [],
652661
};
653662
}, []);
654663

@@ -661,10 +670,12 @@ export function ThreadPrimitiveRoot({ children, ...value }: { children: React.Re
661670
system,
662671
callSettings,
663672
config,
673+
selectedModel: selectedModel ?? undefined,
674+
disabledTools: disabledTools ?? [],
664675
};
665676
});
666677
return unsubscribe;
667-
}, [system, callSettings, config]);
678+
}, [system, callSettings, config, selectedModel, disabledTools]);
668679

669680
useEffect(() => {
670681
if (!storeRef.current) return;
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import z from "zod";
2+
import { BODY_KEY } from "../utils/request-keys";
3+
import type { streamText, Tool } from "ai";
4+
5+
export type ModelResolver = (model?: string) => string | null;
6+
7+
const defaultModelResolver: ModelResolver = (model) => model ?? null;
8+
9+
export const getModel =
10+
(body: Record<string, unknown>, resolveModel: ModelResolver = defaultModelResolver): string | null => {
11+
const model = z.string().nullable().optional().safeParse(body?.[BODY_KEY.SELECTED_MODEL]).data || undefined;
12+
return resolveModel(model);
13+
};
14+
15+
export const getDisabledToolsFiter = (body: Record<string, unknown>): ((tools: Parameters<typeof streamText>[0]['tools']) => Parameters<typeof streamText>[0]['tools']) => {
16+
const disabledTools = z.array(z.string()).nullable().optional().safeParse(body?.[BODY_KEY.DISABLED_TOOLS]).data || [];
17+
return (serverTools: Parameters<typeof streamText>[0]['tools']) => Object.fromEntries(
18+
Object.entries(serverTools as Record<string, Tool<any, any>>).filter(([toolName]) => !disabledTools.includes(toolName)),
19+
)
20+
};
21+
22+
export { BODY_KEY } from "../utils/request-keys";
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
export const BODY_KEY = {
2+
SELECTED_MODEL: "selected-model",
3+
DISABLED_TOOLS: "disabled-tools",
4+
}

0 commit comments

Comments
 (0)