Skip to content

Commit 7343ffe

Browse files
committed
Modify to add a custom prompt in RFF retriever
1 parent 285bb9e commit 7343ffe

File tree

3 files changed

+50
-12
lines changed

3 files changed

+50
-12
lines changed

packages/components/nodes/retrievers/RRFRetriever/RRFRetriever.ts

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,26 @@ class RRFRetriever_Retrievers implements INode {
7777
default: 60,
7878
additionalParams: true,
7979
optional: true
80+
},
81+
{
82+
label: 'System Message',
83+
name: 'systemMessage',
84+
description: 'System message for query generation. Leave empty to use default.',
85+
type: 'string',
86+
rows: 2,
87+
placeholder: 'You are a helpful assistant that generates multiple search queries based on a single input query.',
88+
additionalParams: true,
89+
optional: true
90+
},
91+
{
92+
label: 'Query Generation Prompt',
93+
name: 'queryPrompt',
94+
description: 'Prompt template for generating query variations. Use {input} to refer to the original query.',
95+
type: 'string',
96+
rows: 4,
97+
placeholder: 'Generate multiple search queries related to: {input}. Provide these alternative questions separated by newlines, do not add any numbers.',
98+
additionalParams: true,
99+
optional: true
80100
}
81101
]
82102
this.outputs = [
@@ -110,9 +130,11 @@ class RRFRetriever_Retrievers implements INode {
110130
const k = topK ? parseFloat(topK) : (baseRetriever as VectorStoreRetriever).k ?? 4
111131
const constantC = nodeData.inputs?.c as string
112132
const c = topK ? parseFloat(constantC) : 60
133+
const systemMessage = nodeData.inputs?.systemMessage as string
134+
const queryPrompt = nodeData.inputs?.queryPrompt as string
113135
const output = nodeData.outputs?.output as string
114136

115-
const ragFusion = new ReciprocalRankFusion(llm, baseRetriever as VectorStoreRetriever, q, k, c)
137+
const ragFusion = new ReciprocalRankFusion(llm, baseRetriever as VectorStoreRetriever, q, k, c, systemMessage, queryPrompt)
116138
const retriever = new ContextualCompressionRetriever({
117139
baseCompressor: ragFusion,
118140
baseRetriever: baseRetriever

packages/components/nodes/retrievers/RRFRetriever/ReciprocalRankFusion.ts

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,26 @@ export class ReciprocalRankFusion extends BaseDocumentCompressor {
1111
private readonly queryCount: number
1212
private readonly topK: number
1313
private readonly c: number
14+
private readonly systemMessage?: string
15+
private readonly queryPrompt?: string
1416
private baseRetriever: VectorStoreRetriever
15-
constructor(llm: BaseLanguageModel, baseRetriever: VectorStoreRetriever, queryCount: number, topK: number, c: number) {
17+
constructor(
18+
llm: BaseLanguageModel,
19+
baseRetriever: VectorStoreRetriever,
20+
queryCount: number,
21+
topK: number,
22+
c: number,
23+
systemMessage?: string,
24+
queryPrompt?: string
25+
) {
1626
super()
1727
this.queryCount = queryCount
1828
this.llm = llm
1929
this.baseRetriever = baseRetriever
2030
this.topK = topK
2131
this.c = c
32+
this.systemMessage = systemMessage
33+
this.queryPrompt = queryPrompt
2234
}
2335
async compressDocuments(
2436
documents: Document<Record<string, any>>[],
@@ -29,13 +41,17 @@ export class ReciprocalRankFusion extends BaseDocumentCompressor {
2941
if (documents.length === 0) {
3042
return []
3143
}
44+
// Use custom prompts if provided, otherwise use defaults
45+
const defaultSystemMessage = 'You are a helpful assistant that generates multiple search queries based on a single input query.'
46+
const defaultQueryPrompt =
47+
'Generate multiple search queries related to: {input}. Provide these alternative questions separated by newlines, do not add any numbers.'
48+
49+
const systemMsg = this.systemMessage || defaultSystemMessage
50+
const queryMsg = this.queryPrompt || defaultQueryPrompt
51+
3252
const chatPrompt = ChatPromptTemplate.fromMessages([
33-
SystemMessagePromptTemplate.fromTemplate(
34-
'You are a helpful assistant that generates multiple search queries based on a single input query.'
35-
),
36-
HumanMessagePromptTemplate.fromTemplate(
37-
'Generate multiple search queries related to: {input}. Provide these alternative questions separated by newlines, do not add any numbers.'
38-
),
53+
SystemMessagePromptTemplate.fromTemplate(systemMsg),
54+
HumanMessagePromptTemplate.fromTemplate(queryMsg),
3955
HumanMessagePromptTemplate.fromTemplate('OUTPUT (' + this.queryCount + ' queries):')
4056
])
4157
const llmChain = new LLMChain({
@@ -50,7 +66,7 @@ export class ReciprocalRankFusion extends BaseDocumentCompressor {
5066
})
5167
const docList: Document<Record<string, any>>[][] = []
5268
for (let i = 0; i < queries.length; i++) {
53-
const resultOne = await this.baseRetriever.vectorStore.similaritySearch(queries[i], 5, this.baseRetriever.filter)
69+
const resultOne = await this.baseRetriever.vectorStore.similaritySearch(queries[i], this.topK, this.baseRetriever.filter)
5470
const docs: any[] = []
5571
resultOne.forEach((doc) => {
5672
docs.push(doc)

packages/ui/src/ui-component/dialog/HistoryDialog.jsx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { useState, useEffect } from 'react'
1+
import { useState, useEffect, useCallback } from 'react'
22
import { createPortal } from 'react-dom'
33
import {
44
Dialog,
@@ -43,7 +43,7 @@ const HistoryDialog = ({ show, dialogProps, onCancel, onRestore }) => {
4343
const dispatch = useDispatch()
4444
useNotifier() // Side effect hook
4545

46-
const enqueueSnackbar = (...args) => dispatch(enqueueSnackbarAction(...args))
46+
const enqueueSnackbar = useCallback((...args) => dispatch(enqueueSnackbarAction(...args)), [dispatch])
4747

4848
const [historyItems, setHistoryItems] = useState([])
4949
const [loading, setLoading] = useState(false)
@@ -85,7 +85,7 @@ const HistoryDialog = ({ show, dialogProps, onCancel, onRestore }) => {
8585
}
8686

8787
loadHistory()
88-
}, [show, entityType, entityId, currentPage, reloadTrigger, latestVersion, enqueueSnackbar])
88+
}, [show, entityType, entityId, currentPage, reloadTrigger, enqueueSnackbar])
8989

9090
const handleRestore = async (historyItem) => {
9191
const confirmed = await confirm({

0 commit comments

Comments
 (0)