Skip to content

Commit 280e3b7

Browse files
authored
add parent & root node in trace page (#1164)
1 parent 4e41c97 commit 280e3b7

File tree

2 files changed

+99
-86
lines changed

2 files changed

+99
-86
lines changed

rdagent/log/ui/ds_summary.py

Lines changed: 30 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -154,50 +154,7 @@ def apply_func(cdf: pd.DataFrame):
154154
base_df["Select"] = base_df.index.isin(best_idxs.values)
155155

156156
base_df = st.data_editor(
157-
base_df.style.apply(
158-
lambda col: col.map(lambda val: "background-color: #F0F8FF"),
159-
subset=[
160-
"Baseline Score",
161-
"Bronze Threshold",
162-
"Silver Threshold",
163-
"Gold Threshold",
164-
"Medium Threshold",
165-
],
166-
axis=0,
167-
)
168-
.apply(
169-
lambda col: col.map(lambda val: "background-color: #FFFFE0"),
170-
subset=[
171-
"Ours - Base",
172-
"Ours vs Base",
173-
"Ours vs Bronze",
174-
"Ours vs Silver",
175-
"Ours vs Gold",
176-
],
177-
axis=0,
178-
)
179-
.apply(
180-
lambda col: col.map(lambda val: "background-color: #E6E6FA"),
181-
subset=[
182-
"Script Time",
183-
"Exec Time",
184-
"Exp Gen",
185-
"Coding",
186-
"Running",
187-
],
188-
axis=0,
189-
)
190-
.apply(
191-
lambda col: col.map(lambda val: "background-color: #F0FFF0"),
192-
subset=[
193-
"Best Result",
194-
"SOTA Exp (to_submit)",
195-
"SOTA LID (to_submit)",
196-
"SOTA Exp Score (to_submit)",
197-
"SOTA Exp Score (valid, to_submit)",
198-
],
199-
axis=0,
200-
),
157+
base_df,
201158
column_config={
202159
"Select": st.column_config.CheckboxColumn("Select", help="Stat this trace.", disabled=False),
203160
},
@@ -218,32 +175,43 @@ def apply_func(cdf: pd.DataFrame):
218175
st.text(markdown_table)
219176
with stat_win_right:
220177
Loop_counts = base_df["Total Loops"]
221-
fig = px.histogram(Loop_counts, nbins=10, title="Total Loops Histogram (nbins=10)")
178+
179+
# Create histogram
180+
fig = px.histogram(
181+
Loop_counts, nbins=15, title="Distribution of Total Loops", color_discrete_sequence=["#3498db"]
182+
)
183+
fig.update_layout(title_font_size=16, title_font_color="#2c3e50")
184+
185+
# Calculate statistics
222186
mean_value = Loop_counts.mean()
223187
median_value = Loop_counts.median()
224-
fig.add_vline(
225-
x=mean_value,
226-
line_color="orange",
227-
annotation_text="Mean",
228-
annotation_position="top right",
229-
line_width=3,
230-
)
231-
fig.add_vline(
232-
x=median_value,
233-
line_color="red",
234-
annotation_text="Median",
235-
annotation_position="top right",
236-
line_width=3,
188+
189+
# Add mean and median lines
190+
fig.add_vline(x=mean_value, line_color="#e74c3c", line_width=3)
191+
fig.add_vline(x=median_value, line_color="#f39c12", line_width=3)
192+
193+
fig.add_annotation(
194+
x=0.02,
195+
y=0.95,
196+
xref="paper",
197+
yref="paper",
198+
text=f"<span style='color:#e74c3c; font-weight:bold'>Mean: {mean_value:.1f}</span><br><span style='color:#f39c12; font-weight:bold'>Median: {median_value:.1f}</span>",
199+
showarrow=False,
200+
bgcolor="rgba(255,255,255,0.9)",
201+
bordercolor="rgba(128,128,128,0.5)",
202+
borderwidth=1,
203+
font=dict(size=12, color="#333333"),
237204
)
238-
st.plotly_chart(fig)
205+
206+
st.plotly_chart(fig, use_container_width=True)
239207

240208
# write curve
241209
st.subheader("Curves", divider="rainbow")
242210
curves_win(summary)
243211

244212

245-
with st.container(border=True):
246-
if st.toggle("近3天平均", key="show_3days"):
247-
days_summarize_win()
213+
# with st.container(border=True):
214+
# if st.toggle("近3天平均", key="show_3days"):
215+
# days_summarize_win()
248216
with st.container(border=True):
249217
all_summarize_win()

rdagent/log/ui/ds_trace.py

Lines changed: 69 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,6 @@ def summarize_win():
662662
info5.metric(
663663
"LLM Filter Calls",
664664
llm_filter_call,
665-
delta=-round(llm_filter_call / llm_call, 5),
666665
help=timedelta_to_str(filter_call_duration),
667666
)
668667

@@ -688,11 +687,9 @@ def summarize_win():
688687
help=f"{timeout_stats['running']['timeout']}/{timeout_stats['running']['total']}",
689688
)
690689

690+
final_trace = list(FileStorage(state.log_folder / state.log_path).iter_msg(tag="record.trace"))[-1].content
691691
if show_trace_dag:
692692
st.markdown("### Trace DAG")
693-
final_trace_loop_id = max_id
694-
while "record" not in state.data[final_trace_loop_id]:
695-
final_trace_loop_id -= 1
696693
merge_loops = []
697694
for loop_id in state.llm_data.keys():
698695
if "direct_exp_gen" not in state.llm_data[loop_id]:
@@ -701,15 +698,32 @@ def summarize_win():
701698
i["obj"]["uri"] for i in state.llm_data[loop_id]["direct_exp_gen"]["no_tag"] if "uri" in i["obj"]
702699
]:
703700
merge_loops.append(loop_id)
704-
st.pyplot(trace_figure(state.data[final_trace_loop_id]["record"]["trace"], merge_loops))
701+
st.pyplot(trace_figure(final_trace, merge_loops))
702+
703+
# Find all root nodes (for grouping loops by trace)
704+
root_nodes = {}
705+
parent_nodes = {}
706+
for node in range(len(final_trace.hist)):
707+
parents = final_trace.get_parents(node)
708+
root_nodes[node] = parents[0]
709+
parent_nodes[node] = parents[-2] if len(parents) > 1 else None
710+
root_nodes = {final_trace.idx2loop_id[n]: final_trace.idx2loop_id[r] for n, r in root_nodes.items()}
711+
parent_nodes = {
712+
final_trace.idx2loop_id[n]: final_trace.idx2loop_id[r] if r is not None else r
713+
for n, r in parent_nodes.items()
714+
}
715+
716+
# Generate Summary Table
705717
df = pd.DataFrame(
706718
columns=[
719+
"Root N",
720+
"Parent N",
707721
"Component",
708722
"Hypothesis",
709723
"Reason",
710724
"Others",
711-
"Running Score (valid)",
712-
"Running Score (test)",
725+
"Run Score (valid)",
726+
"Run Score (test)",
713727
"Feedback",
714728
"e-loops(c)",
715729
"e-loops(r)",
@@ -726,6 +740,8 @@ def summarize_win():
726740
sota_loop_id = state.sota_info[1] if state.sota_info else None
727741
for loop in range(min_id, max_id + 1):
728742
loop_data = state.data[loop]
743+
df.loc[loop, "Parent N"] = parent_nodes.get(loop, None)
744+
df.loc[loop, "Root N"] = root_nodes.get(loop, None)
729745
df.loc[loop, "Component"] = loop_data["direct_exp_gen"]["no_tag"].hypothesis.component
730746
df.loc[loop, "Hypothesis"] = loop_data["direct_exp_gen"]["no_tag"].hypothesis.hypothesis
731747
df.loc[loop, "Reason"] = loop_data["direct_exp_gen"]["no_tag"].hypothesis.reason
@@ -766,10 +782,10 @@ def summarize_win():
766782
running_result = loop_data["running"]["no_tag"].result
767783
except AttributeError as e: # Compatible with old versions
768784
running_result = loop_data["running"]["no_tag"].__dict__["result"]
769-
df.loc[loop, "Running Score (valid)"] = str(round(running_result.loc["ensemble"].iloc[0], 5))
785+
df.loc[loop, "Run Score (valid)"] = str(round(running_result.loc["ensemble"].iloc[0], 5))
770786
valid_results[loop] = running_result
771787
except:
772-
df.loc[loop, "Running Score (valid)"] = "❌"
788+
df.loc[loop, "Run Score (valid)"] = "❌"
773789
if "mle_score" not in state.data[loop]:
774790
if "mle_score" in loop_data["running"]:
775791
mle_score_txt = loop_data["running"]["mle_score"]
@@ -787,12 +803,10 @@ def summarize_win():
787803
else "🥉" if state.data[loop]["mle_score"]["bronze_medal"] else ""
788804
)
789805
)
790-
df.loc[loop, "Running Score (test)"] = (
791-
f"{medal_emoji} {state.data[loop]['mle_score']['score']}"
792-
)
806+
df.loc[loop, "Run Score (test)"] = f"{medal_emoji} {state.data[loop]['mle_score']['score']}"
793807
else:
794808
state.data[loop]["mle_score"] = mle_score_txt
795-
df.loc[loop, "Running Score (test)"] = "❌"
809+
df.loc[loop, "Run Score (test)"] = "❌"
796810
else:
797811
mle_score_path = (
798812
replace_ep_path(loop_data["running"]["no_tag"].experiment_workspace.workspace_path)
@@ -811,15 +825,15 @@ def summarize_win():
811825
else "🥉" if state.data[loop]["mle_score"]["bronze_medal"] else ""
812826
)
813827
)
814-
df.loc[loop, "Running Score (test)"] = (
828+
df.loc[loop, "Run Score (test)"] = (
815829
f"{medal_emoji} {state.data[loop]['mle_score']['score']}"
816830
)
817831
else:
818832
state.data[loop]["mle_score"] = mle_score_txt
819-
df.loc[loop, "Running Score (test)"] = "❌"
833+
df.loc[loop, "Run Score (test)"] = "❌"
820834
except Exception as e:
821835
state.data[loop]["mle_score"] = str(e)
822-
df.loc[loop, "Running Score (test)"] = "❌"
836+
df.loc[loop, "Run Score (test)"] = "❌"
823837
else:
824838
if isinstance(state.data[loop]["mle_score"], dict):
825839
medal_emoji = (
@@ -831,13 +845,13 @@ def summarize_win():
831845
else "🥉" if state.data[loop]["mle_score"]["bronze_medal"] else ""
832846
)
833847
)
834-
df.loc[loop, "Running Score (test)"] = f"{medal_emoji} {state.data[loop]['mle_score']['score']}"
848+
df.loc[loop, "Run Score (test)"] = f"{medal_emoji} {state.data[loop]['mle_score']['score']}"
835849
else:
836-
df.loc[loop, "Running Score (test)"] = "❌"
850+
df.loc[loop, "Run Score (test)"] = "❌"
837851

838852
else:
839-
df.loc[loop, "Running Score (valid)"] = "N/A"
840-
df.loc[loop, "Running Score (test)"] = "N/A"
853+
df.loc[loop, "Run Score (valid)"] = "N/A"
854+
df.loc[loop, "Run Score (test)"] = "N/A"
841855

842856
if "coding" in loop_data:
843857
if len([i for i in loop_data["coding"].keys() if isinstance(i, int)]) == 0:
@@ -859,7 +873,38 @@ def summarize_win():
859873

860874
if only_success:
861875
df = df[df["Feedback"] == "✅"]
862-
st.dataframe(df[df.columns[~df.columns.isin(["Hypothesis", "Reason", "Others"])]])
876+
877+
# Add color styling based on root_nodes
878+
def style_dataframe_by_root(df, root_nodes):
879+
# Create a color map for different root nodes - using colors that work well in both light and dark modes
880+
unique_roots = list(set(root_nodes.values()))
881+
colors = [
882+
"rgba(255, 99, 132, 0.3)",
883+
"rgba(54, 162, 235, 0.3)",
884+
"rgba(75, 192, 75, 0.3)",
885+
"rgba(255, 159, 64, 0.3)",
886+
"rgba(153, 102, 255, 0.2)",
887+
"rgba(255, 205, 86, 0.2)",
888+
"rgba(199, 199, 199, 0.2)",
889+
"rgba(83, 102, 255, 0.2)",
890+
]
891+
root_color_map = {root: colors[i % len(colors)] for i, root in enumerate(unique_roots)}
892+
893+
# Create styling function
894+
def apply_color(row):
895+
loop_id = row.name
896+
if loop_id in root_nodes:
897+
root_id = root_nodes[loop_id]
898+
color = root_color_map.get(root_id, "rgba(128, 128, 128, 0.1)")
899+
return [f"background-color: {color}"] * len(row)
900+
return [""] * len(row)
901+
902+
return df.style.apply(apply_color, axis=1)
903+
904+
styled_df = style_dataframe_by_root(
905+
df[df.columns[~df.columns.isin(["Hypothesis", "Reason", "Others"])]], root_nodes
906+
)
907+
st.dataframe(styled_df)
863908

864909
# timeline figure
865910
if state.times:
@@ -882,7 +927,7 @@ def summarize_win():
882927
ensemble_row = vscores.loc[["ensemble"]]
883928
vscores = pd.concat([ensemble_row, vscores.drop("ensemble")])
884929
vscores = vscores.T
885-
test_scores = df["Running Score (test)"].str.replace(r"[🥇🥈🥉]\s*", "", regex=True)
930+
test_scores = df["Run Score (test)"].str.replace(r"[🥇🥈🥉]\s*", "", regex=True)
886931
vscores["test"] = test_scores
887932
vscores.index = [f"L{i}" for i in vscores.index]
888933
vscores.columns.name = metric_name
@@ -902,7 +947,7 @@ def summarize_win():
902947

903948
def comp_stat_func(x: pd.DataFrame):
904949
total_num = x.shape[0]
905-
valid_num = x[x["Running Score (test)"] != "N/A"].shape[0]
950+
valid_num = x[x["Run Score (test)"] != "N/A"].shape[0]
906951
success_num = x[x["Feedback"] == "✅"].shape[0]
907952
avg_e_loops = x["e-loops(c)"].mean()
908953
return pd.Series(
@@ -920,7 +965,7 @@ def comp_stat_func(x: pd.DataFrame):
920965

921966
# component statistics
922967
comp_df = (
923-
df.loc[:, ["Component", "Running Score (test)", "Feedback", "e-loops(c)"]]
968+
df.loc[:, ["Component", "Run Score (test)", "Feedback", "e-loops(c)"]]
924969
.groupby("Component")
925970
.apply(comp_stat_func, include_groups=False)
926971
)

0 commit comments

Comments
 (0)