Skip to content

Commit 0bb326c

Browse files
authored
Merge pull request #6 from DiTo97/patch-1
QoL changes to ReAct agent
2 parents 3b3f80c + 5f3e2a2 commit 0bb326c

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

synalinks/src/modules/agents/react_agent.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
from synalinks.src.utils.tool_utils import Tool
1010

1111

12+
_fn_END = "finish"
13+
14+
1215
def get_decision_question():
1316
"""The default question used for decision-making"""
1417
return "Choose the next function to use based on its name."
@@ -172,13 +175,18 @@ def __init__(
172175
if language_model:
173176
self.decision_language_model = language_model
174177
self.action_language_model = language_model
175-
else:
178+
elif action_language_model and decision_language_model:
176179
self.decision_language_model = decision_language_model
177180
self.action_language_model = action_language_model
181+
else:
182+
raise ValueError(
183+
"You must set either `language_model` "
184+
" or both `action_language_model` and `decision_language_model`."
185+
)
178186

179187
self.prompt_template = prompt_template
180188

181-
if examples:
189+
if not examples:
182190
examples = []
183191
self.examples = examples
184192

@@ -191,8 +199,8 @@ def __init__(
191199
if return_inputs_only and return_inputs_with_trajectory:
192200
raise ValueError(
193201
"You cannot set both "
194-
"`return_inputs_only` and `return_inputs_with_trajectory`"
195-
" arguments to True. Choose only one."
202+
"`return_inputs_only` and `return_inputs_with_trajectory` "
203+
"arguments to True. Choose only one."
196204
)
197205
self.return_inputs_with_trajectory = return_inputs_with_trajectory
198206
self.return_inputs_only = return_inputs_only
@@ -209,7 +217,9 @@ def __init__(
209217
for fn in self.functions:
210218
self.labels.append(Tool(fn).name())
211219

212-
self.labels.append("finish")
220+
assert _fn_END not in self.labels, f"'{_fn_END}' is a reserved keyword and cannot be used as function name"
221+
222+
self.labels.append(_fn_END)
213223

214224
async def build(self, inputs):
215225
current_steps = [inputs]

0 commit comments

Comments
 (0)