Commit 0de5504

return code for pal (#844)
1 parent d564308 commit 0de5504

File tree

2 files changed: +121 -6 lines


docs/modules/chains/examples/pal.ipynb

Lines changed: 112 additions & 4 deletions
@@ -21,14 +21,31 @@
 "from langchain import OpenAI"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "9a58e15e",
+"metadata": {},
+"outputs": [],
+"source": [
+"llm = OpenAI(model_name='code-davinci-002', temperature=0, max_tokens=512)"
+]
+},
+{
+"cell_type": "markdown",
+"id": "095adc76",
+"metadata": {},
+"source": [
+"## Math Prompt"
+]
+},
 {
 "cell_type": "code",
 "execution_count": 2,
 "id": "beddcac7",
 "metadata": {},
 "outputs": [],
 "source": [
-"llm = OpenAI(model_name='code-davinci-002', temperature=0, max_tokens=512)\n",
 "pal_chain = PALChain.from_math_prompt(llm, verbose=True)"
 ]
 },
@@ -64,7 +81,7 @@
 " result = total_pets\n",
 " return result\u001b[0m\n",
 "\n",
-"\u001b[1m> Finished PALChain chain.\u001b[0m\n"
+"\u001b[1m> Finished chain.\u001b[0m\n"
 ]
 },
 {
@@ -82,14 +99,21 @@
 "pal_chain.run(question)"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "0269d20a",
+"metadata": {},
+"source": [
+"## Colored Objects"
+]
+},
 {
 "cell_type": "code",
 "execution_count": 5,
 "id": "e524f81f",
 "metadata": {},
 "outputs": [],
 "source": [
-"llm = OpenAI(model_name='code-davinci-002', temperature=0, max_tokens=512)\n",
 "pal_chain = PALChain.from_colored_object_prompt(llm, verbose=True)"
 ]
 },
@@ -147,10 +171,94 @@
 "pal_chain.run(question)"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "fc3d7f10",
+"metadata": {},
+"source": [
+"## Intermediate Steps\n",
+"You can also use the intermediate steps flag to return the code executed that generates the answer."
+]
+},
+{
+"cell_type": "code",
+"execution_count": 5,
+"id": "9d2d9c61",
+"metadata": {},
+"outputs": [],
+"source": [
+"pal_chain = PALChain.from_colored_object_prompt(llm, verbose=True, return_intermediate_steps=True)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 6,
+"id": "b29b971b",
+"metadata": {},
+"outputs": [],
+"source": [
+"question = \"On the desk, you see two blue booklets, two purple booklets, and two yellow pairs of sunglasses. If I remove all the pairs of sunglasses from the desk, how many purple items remain on it?\""
+]
+},
+{
+"cell_type": "code",
+"execution_count": 8,
+"id": "a2c40c28",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"\n",
+"\n",
+"\u001b[1m> Entering new PALChain chain...\u001b[0m\n",
+"\u001b[32;1m\u001b[1;3m# Put objects into a list to record ordering\n",
+"objects = []\n",
+"objects += [('booklet', 'blue')] * 2\n",
+"objects += [('booklet', 'purple')] * 2\n",
+"objects += [('sunglasses', 'yellow')] * 2\n",
+"\n",
+"# Remove all pairs of sunglasses\n",
+"objects = [object for object in objects if object[0] != 'sunglasses']\n",
+"\n",
+"# Count number of purple objects\n",
+"num_purple = len([object for object in objects if object[1] == 'purple'])\n",
+"answer = num_purple\u001b[0m\n",
+"\n",
+"\u001b[1m> Finished chain.\u001b[0m\n"
+]
+}
+],
+"source": [
+"result = pal_chain({\"question\": question})"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 11,
+"id": "efddd033",
+"metadata": {},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"\"# Put objects into a list to record ordering\\nobjects = []\\nobjects += [('booklet', 'blue')] * 2\\nobjects += [('booklet', 'purple')] * 2\\nobjects += [('sunglasses', 'yellow')] * 2\\n\\n# Remove all pairs of sunglasses\\nobjects = [object for object in objects if object[0] != 'sunglasses']\\n\\n# Count number of purple objects\\nnum_purple = len([object for object in objects if object[1] == 'purple'])\\nanswer = num_purple\""
+]
+},
+"execution_count": 11,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"result['intermediate_steps']"
+]
+},
 {
 "cell_type": "code",
 "execution_count": null,
-"id": "4ab20fec",
+"id": "dfd88594",
 "metadata": {},
 "outputs": [],
 "source": []

langchain/chains/pal/base.py

Lines changed: 9 additions & 2 deletions
@@ -27,6 +27,7 @@ class PALChain(Chain, BaseModel):
     python_globals: Optional[Dict[str, Any]] = None
     python_locals: Optional[Dict[str, Any]] = None
     output_key: str = "result"  #: :meta private:
+    return_intermediate_steps: bool = False

     class Config:
         """Configuration for this pydantic object."""
@@ -48,7 +49,10 @@ def output_keys(self) -> List[str]:

         :meta private:
         """
-        return [self.output_key]
+        if not self.return_intermediate_steps:
+            return [self.output_key]
+        else:
+            return [self.output_key, "intermediate_steps"]

     def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)
@@ -58,7 +62,10 @@ def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         )
         repl = PythonREPL(_globals=self.python_globals, _locals=self.python_locals)
         res = repl.run(code + f"\n{self.get_answer_expr}")
-        return {self.output_key: res.strip()}
+        output = {self.output_key: res.strip()}
+        if self.return_intermediate_steps:
+            output["intermediate_steps"] = code
+        return output

     @classmethod
     def from_math_prompt(cls, llm: BaseLLM, **kwargs: Any) -> PALChain:
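For reference, the two base.py hunks amount to: the chain advertises an extra output key only when the flag is set, and the end of `_call` attaches the generated program under that key. Below is a self-contained sketch of that behavior; `PALChainLike`, its `finish` method, and the example strings are illustrative stand-ins, not the library's code.

from typing import Dict, List

class PALChainLike:
    """Hypothetical stand-in mirroring the control flow added in the diff above."""

    def __init__(self, return_intermediate_steps: bool = False) -> None:
        self.output_key = "result"
        self.return_intermediate_steps = return_intermediate_steps

    @property
    def output_keys(self) -> List[str]:
        # Same branch as the patched output_keys property.
        if not self.return_intermediate_steps:
            return [self.output_key]
        else:
            return [self.output_key, "intermediate_steps"]

    def finish(self, code: str, res: str) -> Dict[str, str]:
        # Same shape as the patched tail of _call: answer always, code only on request.
        output = {self.output_key: res.strip()}
        if self.return_intermediate_steps:
            output["intermediate_steps"] = code
        return output

print(PALChainLike().output_keys)                                # ['result']
print(PALChainLike(return_intermediate_steps=True).output_keys)  # ['result', 'intermediate_steps']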
