Skip to content

Commit e259f7b

Browse files
vid277cdxker
authored andcommitted
feature: add must not filter tool call
1 parent 62c2194 commit e259f7b

File tree

9 files changed

+230
-14
lines changed

9 files changed

+230
-14
lines changed

clients/search-component/src/utils/hooks/chat-context.tsx

Lines changed: 88 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import {
44
defaultPriceToolCallOptions,
55
defaultRelevanceToolCallOptions,
66
defaultSearchToolCallOptions,
7+
defaultNotFilterToolCallOptions,
78
useModalState,
89
} from "./modal-context";
910
import { Chunk } from "../types";
@@ -91,18 +92,18 @@ const ChatContext = createContext<{
9192
rateChatCompletion: (isPositive: boolean, queryId: string | null) => void;
9293
productsWithClicks: ChunkIdWithIndex[];
9394
}>({
94-
askQuestion: async () => { },
95+
askQuestion: async () => {},
9596
currentQuestion: "",
9697
isLoading: false,
9798
loadingText: "",
9899
messages: [],
99-
setCurrentQuestion: () => { },
100-
cancelGroupChat: () => { },
101-
clearConversation: () => { },
102-
chatWithGroup: () => { },
103-
switchToChatAndAskQuestion: async () => { },
104-
stopGeneratingMessage: () => { },
105-
rateChatCompletion: () => { },
100+
setCurrentQuestion: () => {},
101+
cancelGroupChat: () => {},
102+
clearConversation: () => {},
103+
chatWithGroup: () => {},
104+
switchToChatAndAskQuestion: async () => {},
105+
stopGeneratingMessage: () => {},
106+
rateChatCompletion: () => {},
106107
productsWithClicks: [],
107108
});
108109

@@ -521,6 +522,7 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
521522
let useImage = false;
522523
let referenceImageUrls: string[] = [];
523524
let referenceChunks: Chunk[] = [];
525+
let notFilter = false;
524526

525527
if (!groupIds || groupIds.length === 0) {
526528
chatMessageAbortController.current = new AbortController();
@@ -645,6 +647,54 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
645647
}
646648
});
647649

650+
const notFilterPromise = retryOperation(async () => {
651+
if (!curGroup && messages.length > 1) {
652+
return await trieveSDK.getToolCallFunctionParams({
653+
user_message_text: `Here's the previous message thread so far: ${messages.map(
654+
(message) => {
655+
if (
656+
message.type === "system" &&
657+
message.additional?.length &&
658+
props.type === "ecommerce"
659+
) {
660+
const chunks = message.additional
661+
.map((chunk) => {
662+
return JSON.stringify({
663+
title: chunk.metadata?.title || "",
664+
description: chunk.chunk_html || "",
665+
price: chunk.num_value
666+
? `${props.defaultCurrency || ""} ${chunk.num_value}`
667+
: "",
668+
link: chunk.link || "",
669+
});
670+
})
671+
.join("\n\n");
672+
return `\n\n${chunks}${message.text}`;
673+
} else {
674+
return `\n\n${message.text}`;
675+
}
676+
},
677+
)} \n\n${props.notFilterToolCallOptions?.userMessageTextPrefix ?? defaultNotFilterToolCallOptions.userMessageTextPrefix}: ${questionProp || currentQuestion}.`,
678+
image_url: localImageUrl ? localImageUrl : null,
679+
audio_input: curAudioBase64 ? curAudioBase64 : null,
680+
tool_function: {
681+
name: "not_filter",
682+
description:
683+
props.notFilterToolCallOptions?.toolDescription ??
684+
defaultNotFilterToolCallOptions.toolDescription,
685+
parameters: [
686+
{
687+
name: "not_filter",
688+
parameter_type: "boolean",
689+
description:
690+
"Whether or not the user is interested in the products previously shown to them. Set this to true if the user is not interested in the products they were shown or want something different.",
691+
},
692+
],
693+
},
694+
});
695+
}
696+
});
697+
648698
const tagFiltersPromise = retryOperation(async () => {
649699
if (
650700
(!defaultMatchAnyTags || !defaultMatchAnyTags?.length) &&
@@ -693,11 +743,13 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
693743
imageFiltersResp,
694744
tagFiltersResp,
695745
skipSearchResp,
746+
notFilterResp,
696747
] = await Promise.all([
697748
priceFiltersPromise,
698749
imageFiltersPromise,
699750
tagFiltersPromise,
700751
skipSearchPromise,
752+
notFilterPromise,
701753
]);
702754

703755
if (transcribedQuery && curAudioBase64) {
@@ -725,7 +777,8 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
725777
}
726778

727779
useImage = (imageFiltersResp?.parameters &&
728-
(imageFiltersResp.parameters as any)["image"] === true && localImageUrl) as boolean;
780+
(imageFiltersResp.parameters as any)["image"] === true &&
781+
localImageUrl) as boolean;
729782

730783
const match_any_tags = [];
731784
if (tagFiltersResp?.parameters) {
@@ -786,6 +839,15 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
786839
}
787840
}
788841

842+
if (notFilterResp?.parameters) {
843+
const notFilterParam = (notFilterResp.parameters as any)[
844+
"not_filter"
845+
];
846+
if (typeof notFilterParam === "boolean" && notFilterParam) {
847+
notFilter = true;
848+
}
849+
}
850+
789851
clearTimeout(toolCallTimeout);
790852
} catch (e) {
791853
console.error("error getting getToolCallFunctionParams", e);
@@ -821,6 +883,19 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
821883
filters = null;
822884
}
823885

886+
if (notFilter) {
887+
if (filters == null) {
888+
filters = { must_not: [] };
889+
} else if (filters.must_not == null) {
890+
filters.must_not = [];
891+
}
892+
893+
(filters as ChunkFilter)?.must_not?.push({
894+
field: "group_ids",
895+
match_any: groupIdsInChat,
896+
});
897+
}
898+
824899
searchAbortController.current = new AbortController();
825900
if (curGroup) {
826901
setLoadingText("Reading the product's information...");
@@ -883,7 +958,9 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
883958
const searchOverGroupsResp = await retryOperation(async () => {
884959
return await trieveSDK.searchOverGroups(
885960
{
886-
query: questionProp || currentQuestion,
961+
query: notFilter
962+
? `${messages[messages.length - 2]?.text} ${questionProp || currentQuestion}`
963+
: questionProp || currentQuestion,
887964
search_type: "fulltext",
888965
filters: filters,
889966
page_size: 20,
@@ -1296,7 +1373,7 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
12961373
}
12971374
if (audioBase64) {
12981375
setAudioBase64("");
1299-
setTranscribedQuery("");
1376+
setTranscribedQuery("");
13001377
}
13011378
};
13021379

clients/search-component/src/utils/hooks/modal-context.tsx

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,15 @@ export const defaultSearchToolCallOptions: SearchToolCallOptions = {
9595
"Call this tool anytime it seems like we need to skip the search step. This tool tells our system that the user is asking about what they were previously shown.",
9696
};
9797

98+
export const defaultNotFilterToolCallOptions: NotFilterToolCallOptions = {
99+
userMessageTextPrefix: "Here is the user query:",
100+
toolDescription:
101+
"Set to true if the query is not interested in the products they were shown previously or would like to see something different. Ensure that this is only set to true when the user wants to see something different from the previously returned results or is not interested in those previously returned results.",
102+
};
103+
98104
export const defaultPriceToolCallOptions: PriceToolCallOptions = {
99105
toolDescription:
100-
"Only call this function if the query includes details about a price. Decide on which price filters to apply to the available catalog being used within the knowledge base to respond. If the question is slightly like a product name, respond with no filters (all false).",
106+
"Only call this function if the query includes details about a price. Decide on which price filters to apply to the available catalog being used within the knowledge base to respond. If the question is slightly like a product name, respond with no filters (all false). If they don't specify a price number, then respond with no filters (all false).",
101107
minPriceDescription:
102108
"Minimum price of the product. Only set this if a minimum price is mentioned in the query.",
103109
maxPriceDescription:
@@ -110,6 +116,11 @@ export interface PriceToolCallOptions {
110116
maxPriceDescription?: string;
111117
}
112118

119+
export interface NotFilterToolCallOptions {
120+
userMessageTextPrefix?: string;
121+
toolDescription: string;
122+
}
123+
113124
export interface FilterSidebarSection {
114125
key: string;
115126
filterKey: string;
@@ -147,7 +158,6 @@ export function isDefaultSearchQuery(
147158
return typeof question === "object" && "query" in question;
148159
}
149160

150-
151161
export type ModalProps = {
152162
datasetId: string;
153163
apiKey: string;
@@ -182,6 +192,7 @@ export type ModalProps = {
182192
relevanceToolCallOptions?: RelevanceToolCallOptions;
183193
priceToolCallOptions?: PriceToolCallOptions;
184194
searchToolCallOptions?: SearchToolCallOptions;
195+
notFilterToolCallOptions?: NotFilterToolCallOptions;
185196
defaultSearchMode?: SearchModes;
186197
usePagefind?: boolean;
187198
type?: ModalTypes;
@@ -252,6 +263,7 @@ const defaultProps = {
252263
baseUrl: "https://api.trieve.ai",
253264
relevanceToolCallOptions: defaultRelevanceToolCallOptions,
254265
priceToolCallOptions: defaultPriceToolCallOptions,
266+
notFilterToolCallOptions: defaultNotFilterToolCallOptions,
255267
defaultSearchMode: "search" as SearchModes,
256268
placeholder: "Search...",
257269
chatPlaceholder: "Ask Anything...",

clients/ts-sdk/openapi.json

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16203,6 +16203,19 @@
1620316203
}
1620416204
]
1620516205
},
16206+
"NotFilterToolCallOptions": {
16207+
"type": "object",
16208+
"properties": {
16209+
"toolDescription": {
16210+
"type": "string",
16211+
"nullable": true
16212+
},
16213+
"userMessageTextPrefix": {
16214+
"type": "string",
16215+
"nullable": true
16216+
}
16217+
}
16218+
},
1620616219
"OcrStrategy": {
1620716220
"type": "string",
1620816221
"description": "Controls the Optical Character Recognition (OCR) strategy.\n- `All`: Processes all pages with OCR. (Latency penalty: ~0.5 seconds per page)\n- `Auto`: Selectively applies OCR only to pages with missing or low-quality text. When text layer is present the bounding boxes from the text layer are used.",
@@ -16824,6 +16837,14 @@
1682416837
"type": "string",
1682516838
"nullable": true
1682616839
},
16840+
"notFilterToolCallOptions": {
16841+
"allOf": [
16842+
{
16843+
"$ref": "#/components/schemas/NotFilterToolCallOptions"
16844+
}
16845+
],
16846+
"nullable": true
16847+
},
1682716848
"numberOfSuggestions": {
1682816849
"type": "integer",
1682916850
"nullable": true,

clients/ts-sdk/src/types.gen.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2939,6 +2939,11 @@ export type MultiQuery = {
29392939

29402940
export type NewChunkMetadataTypes = SlimChunkMetadataWithArrayTagSet | ChunkMetadata | ContentChunkMetadata;
29412941

2942+
export type NotFilterToolCallOptions = {
2943+
toolDescription?: (string) | null;
2944+
userMessageTextPrefix?: (string) | null;
2945+
};
2946+
29422947
/**
29432948
* Controls the Optical Character Recognition (OCR) strategy.
29442949
* - `All`: Processes all pages with OCR. (Latency penalty: ~0.5 seconds per page)
@@ -3135,6 +3140,7 @@ export type PublicPageParameters = {
31353140
inlineHeader?: (string) | null;
31363141
isTestMode?: (boolean) | null;
31373142
navLogoImgSrcUrl?: (string) | null;
3143+
notFilterToolCallOptions?: ((NotFilterToolCallOptions) | null);
31383144
numberOfSuggestions?: (number) | null;
31393145
openGraphMetadata?: ((OpenGraphMetadata) | null);
31403146
openLinksInNewTab?: (boolean) | null;

frontends/dashboard/src/hooks/usePublicPageSettings.tsx

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import { ApiRoutes } from "../components/Routes";
1919
import { HeroPatterns } from "../pages/dataset/HeroPatterns";
2020
import { createInitializedContext } from "../utils/initialize";
2121
import {
22+
defaultNotFilterToolCallOptions,
2223
defaultOpenGraphMetadata,
2324
defaultPriceToolCallOptions,
2425
defaultRelevanceToolCallOptions,
@@ -100,6 +101,12 @@ export const { use: usePublicPage, provider: PublicPageProvider } =
100101
});
101102
}
102103

104+
if (!extraParams.notFilterToolCallOptions) {
105+
setExtraParams("notFilterToolCallOptions", {
106+
...defaultNotFilterToolCallOptions,
107+
});
108+
}
109+
103110
if (!extraParams.openGraphMetadata) {
104111
setExtraParams("openGraphMetadata", {
105112
...defaultOpenGraphMetadata,

0 commit comments

Comments
 (0)