
Commit 5068b4f

Fixed duplicate figure due to deepcopy of amadeus_answer (#22)
1 parent ff802aa commit 5068b4f

5 files changed, +42 -62 lines

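The commit title points at copy.deepcopy of amadeus_answer as the source of the duplicated figure, and the diff below removes that deepcopy together with the dict-style access it required. As a rough, hedged illustration of the failure mode (the exact rendering path is not shown in this diff), here is a minimal self-contained sketch: Answer and FigureHandle are stand-ins, not the project's real types. It only shows that deep-copying an answer that holds figure handles leaves two independent handles for one plot, so any code that renders both the stored copy and the returned answer draws the figure twice.

import copy
from dataclasses import dataclass, field
from typing import List


class FigureHandle:
    """Stand-in for a live matplotlib Figure held by an answer."""


@dataclass
class Answer:
    """Stand-in for AmadeusAnswer; only the plots field matters here."""
    plots: List[FigureHandle] = field(default_factory=list)


answer = Answer(plots=[FigureHandle()])

# What the old call site did before this commit:
#   manage_memory(original_user_msg, copy.deepcopy(amadeus_answer))
remembered = copy.deepcopy(answer)

print(answer.plots[0] is remembered.plots[0])     # False: two handles exist for one plot
print(len(answer.plots) + len(remembered.plots))  # 2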

amadeusgpt/app_utils.py

Lines changed: 3 additions & 2 deletions
@@ -127,7 +127,7 @@ def __init__(self, amadeus_answer=None, json_entry=None):
         self.data = {}
         self.data["role"] = "ai"
         if amadeus_answer:
-            self.data.update(amadeus_answer.asdict())
+            self.data.update(amadeus_answer)

     def render(self):
         """
@@ -305,7 +305,8 @@ def summon_the_beast():
 def ask_amadeus(question):
     answer = AMADEUS.chat_iteration(
         question
-    ) # use chat_iteration to support some magic commands
+    ).asdict() # use chat_iteration to support some magic commands
+
     # Get the current process
     AmadeusLogger.log_process_memory(log_position="ask_amadeus")
     return answer

amadeusgpt/brains/base.py

Lines changed: 11 additions & 13 deletions
@@ -198,23 +198,23 @@ def manage_memory(cls, user_msg, bot_answer):

         # it's maybe important to keep the answer in the context window. Otherwise we are teaching the model to output empty string
         # need to be very careful as LLMs fewshot learning can wrongly link the answer (even if it is invalid) to the question
-        if bot_answer["ndarray"]:
-            cls.update_history("assistant", bot_answer["function_code"])
+        if bot_answer.ndarray:
+            cls.update_history("assistant", bot_answer.function_code)
         else:

             answer_for_memory = ''

-            if bot_answer["function_code"]:
-                for code in bot_answer["function_code"]:
+            if bot_answer.function_code:
+                for code in bot_answer.function_code:
                     answer_for_memory += code

-            answer_for_memory += '\n' + bot_answer['str_answer']
+            answer_for_memory += '\n' + bot_answer.str_answer

             captions = ''
-            if isinstance(bot_answer['plots'], list):
-                for plot in bot_answer['plots']:
-                    if plot['plot_caption'] !='':
-                        captions+=plot['plot_caption']
+            if isinstance(bot_answer.plots, list):
+                for plot in bot_answer.plots:
+                    if plot.plot_caption !='':
+                        captions+=plot.plot_caption
             if captions!='':
                 answer_for_memory+=captions

@@ -279,12 +279,10 @@ def manage_task_programs(cls, symbol_name, bot_answer):

         # if there is valid function code and there is a corresponding task program in the task program table
         if (
-            "function_code" in bot_answer
+            bot_answer.function_code
             and symbol_name in AnimalBehaviorAnalysis.task_programs
         ):
-            AnimalBehaviorAnalysis.task_programs[symbol_name] = bot_answer[
-                "function_code"
-            ]
+            AnimalBehaviorAnalysis.task_programs[symbol_name] = bot_answer.function_code

     @classmethod
     def print_history(cls):

amadeusgpt/datamodel/amadeus_answer.py

Lines changed: 8 additions & 18 deletions
@@ -19,7 +19,7 @@
 class Figure:
     plot_type: str
     axes: list
-    figure: plt.Figure
+    figure: plt.Figure = field(default=None)
     plot_caption: str = ""


@@ -38,12 +38,6 @@ class AmadeusAnswer:
     ndarray: List[any] = field(default_factory=list)
     role = "ai"

-    def __getitem__(self, key):
-        return getattr(self, key)
-
-    def __setitem__(self, key, value):
-        setattr(self, key, value)
-
     def asdict(self):
         return dataclasses.asdict(self)

@@ -72,7 +66,7 @@ def parse_plot_tuple(self, tu, plot_type="general_plot"):
             data["plot_caption"] = ""

         if "figure" in data and "axes" in data:
-            ret = Figure(**data)
+            ret = Figure(**data)
         else:
             ret = None

@@ -149,27 +143,25 @@ def from_function_returns(cls, function_returns, function_code, thought_process)
         function_returns: Tuple(ret1, ret2, ret3 ... )
         populate the data fields from function returns.
         """
-        instance = AmadeusAnswer()
+        instance = AmadeusAnswer()
         instance.function_code = function_code
         instance.chain_of_thoughts = thought_process
-        # If the function returns are tuple, try to parse the tuple and generate plots
+        # If the function returns are tuple, try to parse the tuple and generate plots
         if isinstance(function_returns, tuple):
             # if the returns contain plots, the return must be tuple (fig, axes)
             _plots = instance.parse_plot_tuple(function_returns)
             if _plots:
-                instance.plots.append(_plots)
-
+                instance.plots.append(_plots)
         # without wrapping it in a list, the following for loop can cause problems
         if isinstance(function_returns, tuple):
             function_returns = list(function_returns)
         else:
             function_returns = [function_returns]
-
+
         for function_return in function_returns:
             if isinstance(function_return, (pd.Series, pd.DataFrame, np.ndarray)):
-                if not isinstance(function_return, np.ndarray):
-                    function_return = function_return.to_numpy()
-                instance.ndarray.append(function_return)
+                if isinstance(function_return, (pd.Series,pd.DataFrame)):
+                    function_return = function_return.to_numpy()
             elif isinstance(function_return, AnimalEvent):
                 instance.get_plots_for_animal_events(function_return)
             elif isinstance(function_return, AnimalAnimalEvent):
@@ -180,6 +172,4 @@ def from_function_returns(cls, function_returns, function_code, thought_process)
                 (matplotlib.figure.Figure, matplotlib.axes._axes.Axes),
             ):
                 instance.str_answer = str(function_return)
-
-
         return instance
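
For context on the datamodel change above: AmadeusAnswer loses its dict-style __getitem__/__setitem__ shims, so callers use plain attribute access and convert to a dict only once, at the UI boundary, via asdict(). A minimal sketch of that access pattern follows; the field names are taken from the diff, but the class itself is illustrative, not the project's full definition.

import dataclasses
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class AnswerSketch:
    # Field names mirror those visible in the diff; the real AmadeusAnswer has more.
    function_code: List[str] = field(default_factory=list)
    str_answer: str = ""
    plots: List[Any] = field(default_factory=list)
    summary: str = ""
    chain_of_thoughts: str = ""

    def asdict(self):
        # dataclasses.asdict recursively converts (and copies) field values,
        # so it is reserved for the final hand-off to the UI.
        return dataclasses.asdict(self)


answer = AnswerSketch(function_code=["def task_program(): ..."], str_answer="42")
answer.summary = "explanation"       # attribute write, no __setitem__ shim
print(answer.function_code)          # attribute read, no __getitem__ shim
print(answer.asdict()["summary"])    # dict form only at the UI boundary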

amadeusgpt/gui.py

Lines changed: 2 additions & 2 deletions
@@ -56,7 +56,7 @@ def select_roi_from_video(video_filename):
     fig, axs = plt.subplots(1)
     axs.imshow(frame)
     selector = ROISelector(axs)
-    plt.show()
+    #plt.show()
     return selector.paths


@@ -94,4 +94,4 @@ def select_roi_from_plot(fig, ax):
     plt.ylabel("Count of True values in mask")
     plt.title("Animal occurrence in ROI")
     # Show the plot
-    plt.show()
+    #plt.show()

amadeusgpt/main.py

Lines changed: 18 additions & 27 deletions
@@ -305,45 +305,38 @@ def chat(
             text,
             function_code,
             thought_process,
-        ) = cls.code_generator_brain.parse_openai_response(response)
-
-
-
-
+        ) = cls.code_generator_brain.parse_openai_response(response)

         with open("temp_for_debug.json", "w") as f:
             out = {'function_code': function_code,
                    'query': rephrased_user_msg}
             json.dump(out, f, indent=4)
-        # handle_function_codes gives the answer with function outputs
+        # handle_function_codes gives the answer with function outputs
         amadeus_answer = cls.core_loop(
             rephrased_user_msg, text, function_code, thought_process
         )
-        amadeus_answer = amadeus_answer.asdict()
-
         # export the generated function to code_output
-        if amadeus_answer['function_code'] and code_output != "":
+        if amadeus_answer.function_code and code_output != "":
             cls.export_function_code(
-                original_user_msg, amadeus_answer["function_code"], code_output
+                original_user_msg, amadeus_answer.function_code, code_output
             )

         # if there is an error or the function code is empty, we want to make sure we prevent ChatGPT to learn to output nothing from few-shot learning
         # is this used anymore?
-        if amadeus_answer["has_error"]:
+        if amadeus_answer.has_error:
             cls.code_generator_brain.context_window[-1][
                 "content"
             ] += "\n While executing the code above, there was error so it is not correct answer\n"

-        elif amadeus_answer["has_error"]:
+        elif amadeus_answer.has_error:
             cls.code_generator_brain.context_window.pop()
             cls.code_generator_brain.history.pop()
         else:
             # needs to manage memory of Amadeus for context window management and state restore etc.
             # we have it remember user's original question instead of the rephrased one for better
             cls.code_generator_brain.manage_memory(
-                original_user_msg, copy.deepcopy(amadeus_answer)
-            )
-
+                original_user_msg, amadeus_answer
+            )
         return amadeus_answer

     # this should become an async function so the user can continue to ask question
@@ -354,7 +347,7 @@ def execute_python_function(
         function_code,
     ):
         # we might register a few helper functions into globals()
-        result = None
+        result = None
         exec(function_code, globals())
         if "task_program" not in globals():
             return None
@@ -534,8 +527,8 @@ def collect_function_result(cls, function_returns, function_code, thought_proces
             function_returns, function_code, thought_process
         )

-        if cls.plot:
-            plt.show()
+        #if cls.plot:
+        # plt.show()

         # deduplicate as both events and plot could append plots
         return amadeus_answer
@@ -608,9 +601,7 @@ def chat_iteration(cls, user_input, code_output="", functions=None, rephrased=[]
             code_output=code_output,
             functions=functions,
         )
-        if not isinstance(answer, AmadeusAnswer):
-            answer = AmadeusAnswer.fromdict(answer)
-
+
         return answer

     # @classmethod
@@ -653,19 +644,19 @@ def compile_amadeus_answer_when_no_error(
             ):
                 explanation = cls.explainer_brain.generate_explanation(
                     user_query,
-                    amadeus_answer["chain_of_thoughts"],
-                    amadeus_answer["str_answer"],
-                    amadeus_answer["plots"]
+                    amadeus_answer.chain_of_thoughts,
+                    amadeus_answer.str_answer,
+                    amadeus_answer.plots
                 )
-                amadeus_answer["summary"] = explanation
+                amadeus_answer.summary = explanation
                 AmadeusLogger.info("Generated explanation from the explainer:")
                 AmadeusLogger.info(explanation)
             else:
-                amadeus_answer["summary"] = ""
+                amadeus_answer.summary = ""
         else:
             # if gpt apologies or asks for clarification, it will be no error but no function
             amadeus_answer = AmadeusAnswer()
-            amadeus_answer["chain_of_thoughts"] = thought_process
+            amadeus_answer.chain_of_thoughts = thought_process
         return amadeus_answer

     @classmethod
