Skip to content

Commit f3df726

Browse files
committed
Added roi object names into system prompt and fixed a bug
1 parent 5b432f2 commit f3df726

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

amadeusgpt/analysis_objects/llm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,11 +229,12 @@ def update_system_prompt(self, sandbox):
229229
task_program_docs = sandbox.get_task_program_docs()
230230
query_block = sandbox.get_query_block()
231231

232-
keypoint_names = sandbox.exec_namespace[
232+
behavior_analysis = sandbox.exec_namespace[
233233
"behavior_analysis"
234-
].get_keypoint_names()
234+
]
235+
235236
self.system_prompt = _get_system_prompt(
236-
query_block, core_api_docs, task_program_docs, keypoint_names
237+
query_block, core_api_docs, task_program_docs, behavior_analysis
237238
)
238239

239240
# update both history and context window

amadeusgpt/managers/relationship_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def get_animals_objects_relationships(
5252
for animal in animals:
5353
if animal_bodyparts_names is not None:
5454
# the keypoints of animal get updated when we update the roi bodypart names
55-
animal.update_roi_keypoint_names(animal_bodyparts_names)
55+
animal.update_roi_keypoint_by_names(animal_bodyparts_names)
5656
for object in objs:
5757
animal_object_relations = AnimalObjectRelationship(
5858
animal, object, animal_bodyparts_names

amadeusgpt/system_prompts/code_generator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ def _get_system_prompt(
22
query,
33
core_api_docs,
44
task_program_docs,
5-
keypoint_names,
5+
behavior_analysis,
66
):
7+
keypoint_names = behavior_analysis.get_keypoint_names()
8+
roi_object_names = behavior_analysis.get_roi_object_names()
79
system_prompt = f"""
810
You are helpful AI assistant. Your job is to answer user queries.
911
Importantly, before you write the code, you need to explain whether the question can be answered accurately by code. If not, ask users to give more information.
@@ -66,6 +68,7 @@ def get_watching_events(config: Config):
6668
{query}\n{core_api_docs}\n{task_program_docs}\n
6769
6870
The keypoint names for the animals are: {keypoint_names}
71+
Available ROI objects are: {roi_object_names}
6972
7073
FORMATTING:
7174
1) If you are asked to provide plotting code, make sure you don't call plt.show() but return a tuple figure, axs

0 commit comments

Comments
 (0)