Skip to content

Commit b7c0cf5

Browse files
authored
chore: show merge loop in figure (#1105)
1 parent f1ef140 commit b7c0cf5

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

rdagent/log/ui/ds_trace.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,15 @@ def summarize_win():
588588
final_trace_loop_id = max_id
589589
while "record" not in state.data[final_trace_loop_id]:
590590
final_trace_loop_id -= 1
591-
st.pyplot(trace_figure(state.data[final_trace_loop_id]["record"]["trace"]))
591+
merge_loops = []
592+
for loop_id in state.llm_data.keys():
593+
if "direct_exp_gen" not in state.llm_data[loop_id]:
594+
continue
595+
if "scenarios.data_science.proposal.exp_gen.merge:trace" in [
596+
i["obj"]["uri"] for i in state.llm_data[loop_id]["direct_exp_gen"]["no_tag"] if "uri" in i["obj"]
597+
]:
598+
merge_loops.append(loop_id)
599+
st.pyplot(trace_figure(state.data[final_trace_loop_id]["record"]["trace"], merge_loops))
592600
df = pd.DataFrame(
593601
columns=[
594602
"Component",

rdagent/log/ui/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ def curve_figure(scores: pd.DataFrame) -> go.Figure:
598598
return fig
599599

600600

601-
def trace_figure(trace: Trace):
601+
def trace_figure(trace: Trace, merge_loops: list = []):
602602
G = nx.DiGraph()
603603

604604
# Calculate the number of ancestors for each node (root node is 0, more ancestors means lower level)
@@ -610,7 +610,7 @@ def get_display_name(idx: int):
610610
"""
611611
Convert to index in the queue (enque id) to loop_idx for easier understanding.
612612
"""
613-
if hasattr(trace, "idx2loop_id"):
613+
if hasattr(trace, "idx2loop_id") and idx in trace.idx2loop_id:
614614
# FIXME: only keep me after it is stable. Just for compatibility.
615615
return f"L{trace.idx2loop_id[idx]} ({idx})"
616616
return f"L{idx}"
@@ -684,7 +684,8 @@ def parent_avg_pos(node):
684684
pos[node] = (x, y)
685685

686686
fig, ax = plt.subplots(figsize=(8, 6))
687-
nx.draw(G, pos, with_labels=True, arrows=True, node_color="skyblue", node_size=100, font_size=5, ax=ax)
687+
color_map = ["tomato" if node in [get_display_name(idx) for idx in merge_loops] else "skyblue" for node in G]
688+
nx.draw(G, pos, with_labels=True, arrows=True, node_color=color_map, node_size=100, font_size=5, ax=ax)
688689
return fig
689690

690691

0 commit comments

Comments
 (0)