Skip to content

Commit dd57c18

Browse files
authored
Flexibly exclude tools (#14)
1 parent 332c028 commit dd57c18

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

src/fhda/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@
1919
else:
2020
DATA_STORAGE_PATH = Path("/storage")
2121

22-
VALID_FROM_TASK_KWARGS = ["run_notebook_on_edit"]
22+
VALID_FROM_TASK_KWARGS = ["run_notebook_on_edit", "exclude_tools"]

src/fhda/data_analysis_env.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def __init__(
3636
eval_mode: EvalAnswerMode | None = None,
3737
metadata: dict[str, Any] | None = None, # used for NBEvalExpt
3838
mcqs: list[MultipleChoiceQuestion] | None = None,
39+
exclude_tools: list[str] | None = None,
3940
**kwargs,
4041
):
4142
super().__init__(**kwargs)
@@ -49,21 +50,38 @@ def __init__(
4950
self.system_prompt = system_prompt
5051
self.metadata = metadata
5152
self.question_rewards: dict[str, int] = {}
53+
self.exclude_tools = exclude_tools
5254

5355
async def reset(self) -> tuple[Messages, list[Tool]]:
5456
# Discard base class's init_obs and make our own with the problem statement
5557
_, tools = await super().reset()
58+
if self.exclude_tools:
59+
tools = [
60+
tool
61+
for tool in tools
62+
if tool._tool_fn.__name__ not in self.exclude_tools
63+
]
64+
5665
messages = [
5766
Message(content=self.problem),
5867
self.get_env_state_msg(),
5968
]
69+
# If the list_workdir tool is excluded, add the content of the working directory to the initial message
70+
if self.exclude_tools is not None and "list_workdir" in self.exclude_tools:
71+
messages.append(
72+
Message(
73+
content=f"Here is the content of your working directory:\n{self.list_workdir()}"
74+
)
75+
)
76+
6077
if self.system_prompt:
6178
messages.append(Message(role="system", content=self.system_prompt))
6279
init_obs = cast(
6380
Messages,
6481
messages,
6582
)
66-
83+
print(messages)
84+
print(tools)
6785
return init_obs, tools
6886

6987
async def submit_answer(self, answer: str) -> str: # type: ignore[override]

0 commit comments

Comments
 (0)