
Commit 91d7fd2

feat: add custom prompt for QAEvalChain chain (#610)
I originally modified only `from_llm` to accept the prompt, but I realized that if the input keys used in the custom prompt didn't match the default prompt's keys, it wouldn't work because of how `apply` works. So I changed the `evaluate` method to check whether the prompt is the default one and, if not, whether the input keys match the prompt's keys, updating the inputs appropriately. Let me know if there is a better way to do this. Also added the custom prompt to the QA eval notebook.
1 parent 1787c47 commit 91d7fd2
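
For readers following along, here is a minimal sketch of the key-remapping idea the commit message describes. It is illustrative only, not the code from this commit: the helper name `remap_inputs_for_prompt` and its exact logic are assumptions. It simply pairs the default evaluation keys ("query", "answer", "result") positionally with whatever `input_variables` a custom prompt declares, so that `LLMChain.apply` can look the values up by name.

from typing import Dict, List

from langchain.prompts.prompt import PromptTemplate

# Keys produced for the default QA eval prompt.
_DEFAULT_KEYS = ("query", "answer", "result")


def remap_inputs_for_prompt(
    prompt: PromptTemplate, inputs: List[Dict[str, str]]
) -> List[Dict[str, str]]:
    """Rename each input dict's keys to match the prompt's input_variables.

    Illustrative helper, not part of the commit.
    """
    if tuple(prompt.input_variables) == _DEFAULT_KEYS:
        return inputs  # default prompt: nothing to remap
    return [
        # Pair the default keys positionally with the custom prompt's variables.
        {var: values[key] for var, key in zip(prompt.input_variables, _DEFAULT_KEYS)}
        for values in inputs
    ]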

File tree

2 files changed: +73 −5 lines changed


docs/use_cases/evaluation/question_answering.ipynb

Lines changed: 52 additions & 2 deletions
@@ -190,6 +190,51 @@
     " print()"
    ]
   },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "782ae8c8",
+   "metadata": {},
+   "source": [
+    "## Customize Prompt\n",
+    "\n",
+    "You can also customize the prompt that is used. Here is an example prompting it using a score from 0 to 10.\n",
+    "The custom prompt requires 3 input variables: \"query\", \"answer\" and \"result\". Where \"query\" is the question, \"answer\" is the ground truth answer, and \"result\" is the predicted answer."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "153425c4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.prompts.prompt import PromptTemplate\n",
+    "\n",
+    "_PROMPT_TEMPLATE = \"\"\"You are an expert professor specialized in grading students' answers to questions.\n",
+    "You are grading the following question:\n",
+    "{query}\n",
+    "Here is the real answer:\n",
+    "{answer}\n",
+    "You are grading the following predicted answer:\n",
+    "{result}\n",
+    "What grade do you give from 0 to 10, where 0 is the lowest (very low similarity) and 10 is the highest (very high similarity)?\n",
+    "\"\"\"\n",
+    "\n",
+    "PROMPT = PromptTemplate(input_variables=[\"query\", \"answer\", \"result\"], template=_PROMPT_TEMPLATE)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0a3b0fb7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "evalchain = QAEvalChain.from_llm(llm=llm,prompt=PROMPT)\n",
+    "evalchain.evaluate(examples, predictions, question_key=\"question\", answer_key=\"answer\", prediction_key=\"text\")"
+   ]
+  },
   {
    "cell_type": "markdown",
    "id": "aaa61f0c",
@@ -271,7 +316,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
+   "display_name": ".venv",
    "language": "python",
    "name": "python3"
   },
@@ -285,7 +330,12 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.10.9"
+   "version": "3.9.7 (default, Sep 16 2021, 08:50:36) \n[Clang 10.0.0 ]"
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "53f3bc57609c7a84333bb558594977aa5b4026b1d6070b93987956689e367341"
+   }
   }
  },
  "nbformat": 4,

langchain/evaluation/qa/eval_chain.py

Lines changed: 21 additions & 3 deletions
@@ -3,6 +3,7 @@

 from typing import Any, List

+from langchain import PromptTemplate
 from langchain.chains.llm import LLMChain
 from langchain.evaluation.qa.eval_prompt import PROMPT
 from langchain.llms.base import BaseLLM
@@ -12,9 +13,25 @@ class QAEvalChain(LLMChain):
     """LLM Chain specifically for evaluating question answering."""

     @classmethod
-    def from_llm(cls, llm: BaseLLM, **kwargs: Any) -> QAEvalChain:
-        """Load QA Eval Chain from LLM."""
-        return cls(llm=llm, prompt=PROMPT, **kwargs)
+    def from_llm(
+        cls, llm: BaseLLM, prompt: PromptTemplate = PROMPT, **kwargs: Any
+    ) -> QAEvalChain:
+        """Load QA Eval Chain from LLM.
+
+        Args:
+            llm (BaseLLM): the base language model to use.
+
+            prompt (PromptTemplate): A prompt template containing the input_variables:
+                'input', 'answer' and 'result' that will be used as the prompt
+                for evaluation.
+                Defaults to PROMPT.
+
+            **kwargs: additional keyword arguments.
+
+        Returns:
+            QAEvalChain: the loaded QA eval chain.
+        """
+        return cls(llm=llm, prompt=prompt, **kwargs)

     def evaluate(
         self,
@@ -33,4 +50,5 @@ def evaluate(
                 "result": predictions[i][prediction_key],
             }
             inputs.append(_input)
+
         return self.apply(inputs)
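
For context, a compact usage sketch that mirrors the new notebook cells above. It assumes `llm`, `examples`, `predictions`, and `_PROMPT_TEMPLATE` are defined as in the evaluation notebook, and that `QAEvalChain` is importable from `langchain.evaluation.qa`; treat it as an example, not part of the commit.

from langchain.evaluation.qa import QAEvalChain
from langchain.prompts.prompt import PromptTemplate

# Custom 0-to-10 grading prompt; it must expose the "query", "answer"
# and "result" input variables that evaluate() fills in.
PROMPT = PromptTemplate(
    input_variables=["query", "answer", "result"],
    template=_PROMPT_TEMPLATE,  # the grading template from the notebook cell above
)

eval_chain = QAEvalChain.from_llm(llm=llm, prompt=PROMPT)
graded = eval_chain.evaluate(
    examples,      # ground-truth question/answer pairs
    predictions,   # predicted answers to grade
    question_key="question",
    answer_key="answer",
    prediction_key="text",
)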
