Skip to content

Commit 0c9f193

Browse files
authored
feat: add enable_cache toggle for UI data caching (#1075)
1 parent 56ba15a commit 0c9f193

File tree

3 files changed

+12
-2
lines changed

3 files changed

+12
-2
lines changed

rdagent/log/ui/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,7 @@ class UIBasePropSetting(ExtendedBaseSettings):
1818

1919
trace_folder: str = "./traces"
2020

21+
enable_cache: bool = True
22+
2123

2224
UI_SETTING = UIBasePropSetting()

rdagent/log/ui/ds_trace.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from rdagent.app.data_science.loop import DataScienceRDLoop
1616
from rdagent.log.storage import FileStorage
17+
from rdagent.log.ui.conf import UI_SETTING
1718
from rdagent.log.ui.utils import curve_figure, load_times, trace_figure
1819
from rdagent.log.utils import (
1920
LogColors,
@@ -45,7 +46,6 @@ def convert_defaultdict_to_dict(d):
4546
return d
4647

4748

48-
@st.cache_data(persist=True)
4949
def load_data(log_path: Path):
5050
data = defaultdict(lambda: defaultdict(dict))
5151
llm_data = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
@@ -132,6 +132,10 @@ def load_data(log_path: Path):
132132
)
133133

134134

135+
if UI_SETTING.enable_cache:
136+
load_data = st.cache_data(persist=True)(load_data)
137+
138+
135139
def load_stdout(stdout_path: Path):
136140
if stdout_path.exists():
137141
stdout = stdout_path.read_text()

rdagent/log/ui/llm_st.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import streamlit as st
99
from streamlit import session_state
1010

11+
from rdagent.log.ui.conf import UI_SETTING
1112
from rdagent.log.utils import extract_evoid, extract_loopid_func_name
1213

1314
st.set_page_config(layout="wide", page_title="debug_llm", page_icon="🎓", initial_sidebar_state="expanded")
@@ -18,7 +19,6 @@
1819
args = parser.parse_args()
1920

2021

21-
@st.cache_data
2222
def get_folders_sorted(log_path):
2323
"""缓存并返回排序后的文件夹列表,并加入进度打印"""
2424
with st.spinner("正在加载文件夹列表..."):
@@ -31,6 +31,10 @@ def get_folders_sorted(log_path):
3131
return [folder.name for folder in folders]
3232

3333

34+
if UI_SETTING.enable_cache:
35+
get_folders_sorted = st.cache_data(get_folders_sorted)
36+
37+
3438
# 设置主日志路径
3539
main_log_path = Path(args.log_dir) if args.log_dir else Path("./log")
3640
if not main_log_path.exists():

0 commit comments

Comments
 (0)