Skip to content

Commit e7fb749

Browse files
committed
Changed the out video name to use function name
1 parent 23db4ac commit e7fb749

File tree

5 files changed

+18
-11
lines changed

5 files changed

+18
-11
lines changed

amadeusgpt/app_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,11 @@ def render(self):
217217
filename = save_figure_to_tempfile(fig)
218218
st.image(filename, width=600)
219219
elif render_key == "out_videos":
220+
print ('out_videos')
221+
print (render_value)
220222
for video_path in render_value:
221-
st.video(video_path)
223+
if os.path.exists(video_path):
224+
st.video(video_path)
222225

223226

224227
class Messages:

amadeusgpt/managers/event_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,13 +192,16 @@ def get_animals_state_events(
192192
state = self.animal_manager.query_animal_states(
193193
sender_animal_name, query_name
194194
)
195+
196+
195197
# must be of shape (n_frames, n_kpts, n_dim)
196198
assert (
197199
len(state.shape) == 3
198200
), f"state shape is {state.shape}. It must be of shape (n_frames, n_kpts, n_dim)"
199201
if len(state.shape) == 3:
200202
state = np.nanmedian(state, axis=(1, 2))
201203
relation_string = "state" + comparison
204+
202205
mask = eval(relation_string)
203206

204207
events = Event.mask2events(

amadeusgpt/managers/visual_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,7 @@ def write_video(self, out_folder, video_file_path, out_name, events):
546546
# sort the data by start_time
547547
data = sorted(data, key=lambda x: x["start_time"])
548548
total_duration = sum([event.duration_in_seconds for event in events])
549-
if total_duration < 0.5:
549+
if total_duration < 0.0:
550550
return
551551

552552
fourcc = cv2.VideoWriter_fourcc(*"avc1") # Adjust the codec as needed

amadeusgpt/programs/sandbox.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -372,12 +372,12 @@ def register_task_program(self, code, parents=None, mutation_from=None):
372372
def register_llm(self, name, llm):
373373
self.llms[name] = llm
374374

375-
def events_to_videos(self, events, query):
375+
def events_to_videos(self, events, function_name):
376376
behavior_analysis = self.exec_namespace["behavior_analysis"]
377377
visual_manager = behavior_analysis.visual_manager
378378
out_folder = "event_clips"
379379
os.makedirs(out_folder, exist_ok=True)
380-
behavior_name = "_".join(query.split(" "))
380+
behavior_name = "_".join(function_name.split(" "))
381381
video_file = self.config["video_info"]["video_file_path"]
382382
return visual_manager.generate_video_clips_from_events(
383383
out_folder, video_file, events, behavior_name
@@ -411,7 +411,7 @@ def render_qa_message(self, qa_message):
411411
)
412412
)
413413
qa_message["out_videos"].append(
414-
self.events_to_videos(e, qa_message["query"])
414+
self.events_to_videos(e, self.get_function_name_from_string(qa_message["code"]))
415415
)
416416

417417
elif (
@@ -429,7 +429,7 @@ def render_qa_message(self, qa_message):
429429
visual_manager.get_ethogram_visualization(events=function_rets)
430430
)
431431
qa_message["out_videos"].append(
432-
self.events_to_videos(function_rets, qa_message["query"])
432+
self.events_to_videos(function_rets, self.get_function_name_from_string(qa_message["code"]))
433433
)
434434
else:
435435
pass
@@ -521,19 +521,20 @@ def render_temp_message(query, sandbox):
521521
print("after code execution")
522522
print(len(qa_message["function_rets"]))
523523
events = qa_message["function_rets"]
524-
# for event in events:
525-
# print (event.start, event.end)
526-
# qa_message = sandbox.render_qa_message(qa_message)
527-
524+
for event in events:
525+
print (event)
526+
528527
if qa_message["function_rets"] is not None:
529528
st.markdown(qa_message["function_rets"])
530529

531530
plots = qa_message["plots"]
531+
print ('plots', plots)
532532
for fig, axe in plots:
533533
filename = save_figure_to_tempfile(fig)
534534
st.image(filename, width=600)
535535

536536
videos = qa_message["out_videos"]
537+
print ('videos', videos)
537538
for video in videos:
538539
st.video(video)
539540

examples/MABe/example.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
{
33
"role": "human",
44
"error": NaN,
5-
"query": "Define <|chases|> as a social behavior where closest distance between this animal and other animals is less than 40 pixels and the angle between this and other animals have to be less than 30 and this animal has to travel faster than 2."
5+
"query": "Define <|chases|> as a social behavior where closest distance between this animal and other animals is less than 40 pixels and the angle between this and other animals have to be less than 30 and this animal has to travel faster than 0.2."
66
},
77
{
88
"role": "human",

0 commit comments

Comments
 (0)