Skip to content

Commit 234cd58

Browse files
authored
save more time info (#1151)
1 parent 6b48f24 commit 234cd58

File tree

4 files changed

+45
-26
lines changed

4 files changed

+45
-26
lines changed

rdagent/log/ui/ds_trace.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -534,19 +534,17 @@ def main_win(loop_id, llm_data=None):
534534
)
535535
if "running" in loop_data:
536536
# get last SOTA_exp_to_submit
537-
current_trace = loop_data["record"]["trace"]
538-
current_selection = current_trace.get_current_selection()
539-
if len(current_selection) > 0: # TODO: Why current_selection can be "()"?
540-
current_idx = current_selection[0]
541-
parent_idxs = current_trace.get_parents(current_idx)
542-
if len(parent_idxs) >= 2 and hasattr(current_trace, "idx2loop_id"):
543-
parent_idx = parent_idxs[-2]
544-
parent_loop_id = current_trace.idx2loop_id[parent_idx]
545-
sota_exp = state.data[parent_loop_id]["record"].get("sota_exp_to_submit", None)
546-
else:
547-
sota_exp = None
548-
else:
549-
sota_exp = None
537+
sota_exp = None
538+
if "record" in loop_data:
539+
current_trace = loop_data["record"]["trace"]
540+
current_selection = current_trace.get_current_selection()
541+
if len(current_selection) > 0: # TODO: Why current_selection can be "()"?
542+
current_idx = current_selection[0]
543+
parent_idxs = current_trace.get_parents(current_idx)
544+
if len(parent_idxs) >= 2 and hasattr(current_trace, "idx2loop_id"):
545+
parent_idx = parent_idxs[-2]
546+
parent_loop_id = current_trace.idx2loop_id[parent_idx]
547+
sota_exp = state.data[parent_loop_id]["record"].get("sota_exp_to_submit", None)
550548

551549
running_win(
552550
loop_data["running"],

rdagent/oai/backend/base.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -257,16 +257,16 @@ def build_chat_completion(self, user_prompt: str, *args, **kwargs) -> str: # ty
257257
messages = self.build_chat_completion_message(user_prompt)
258258

259259
with logger.tag(f"session_{self.conversation_id}"):
260-
start_time = time.time()
260+
start_time = datetime.now(pytz.timezone("Asia/Shanghai"))
261261
response: str = self.api_backend._try_create_chat_completion_or_embedding( # noqa: SLF001
262262
*args,
263263
messages=messages,
264264
chat_completion=True,
265265
**kwargs,
266266
)
267-
end_time = time.time()
267+
end_time = datetime.now(pytz.timezone("Asia/Shanghai"))
268268
logger.log_object(
269-
{"user": user_prompt, "resp": response, "duration": end_time - start_time}, tag="debug_llm"
269+
{"user": user_prompt, "resp": response, "start": start_time, "end": end_time}, tag="debug_llm"
270270
)
271271

272272
messages.append(
@@ -409,19 +409,19 @@ def build_messages_and_create_chat_completion( # type: ignore[no-untyped-def]
409409
shrink_multiple_break=shrink_multiple_break,
410410
)
411411

412-
start_time = time.time()
412+
start_time = datetime.now(pytz.timezone("Asia/Shanghai"))
413413
resp = self._try_create_chat_completion_or_embedding( # type: ignore[misc]
414414
*args,
415415
messages=messages,
416416
chat_completion=True,
417417
chat_cache_prefix=chat_cache_prefix,
418418
**kwargs,
419419
)
420-
end_time = time.time()
420+
end_time = datetime.now(pytz.timezone("Asia/Shanghai"))
421421
if isinstance(resp, list):
422422
raise ValueError("The response of _try_create_chat_completion_or_embedding should be a string.")
423423
logger.log_object(
424-
{"system": system_prompt, "user": user_prompt, "resp": resp, "duration": end_time - start_time},
424+
{"system": system_prompt, "user": user_prompt, "resp": resp, "start": start_time, "end": end_time},
425425
tag="debug_llm",
426426
)
427427
return resp

rdagent/scenarios/data_science/proposal/exp_gen/router/__init__.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import asyncio
4-
from datetime import timedelta
4+
from datetime import datetime, timedelta, timezone
55
from typing import TYPE_CHECKING
66

77
from rdagent.app.data_science.conf import DS_RD_SETTING
@@ -91,12 +91,16 @@ async def async_gen(self, trace: DSTrace, loop: LoopBase) -> DSExperiment:
9191
trace.set_current_selection(local_selection)
9292

9393
ds_plan = self.planner.plan(trace) if DS_RD_SETTING.enable_planner else DSExperimentPlan()
94+
95+
start = datetime.now(timezone.utc)
96+
exp_gen_type = ""
9497
if (
9598
(not timer.started or timer.remain_time() >= timedelta(hours=DS_RD_SETTING.merge_hours))
9699
and trace.sota_experiment(selection=local_selection) is None
97100
and DS_RD_SETTING.enable_draft_before_first_sota
98101
):
99102
exp = self.draft_exp_gen.gen(trace, plan=ds_plan)
103+
exp_gen_type = type(self.draft_exp_gen).__name__
100104
elif (
101105
timer.started
102106
and timer.remain_time() < timedelta(hours=DS_RD_SETTING.merge_hours)
@@ -105,10 +109,20 @@ async def async_gen(self, trace: DSTrace, loop: LoopBase) -> DSExperiment:
105109
DS_RD_SETTING.coding_fail_reanalyze_threshold = 100000
106110
DS_RD_SETTING.consecutive_errors = 100000
107111
exp = self.merge_exp_gen.gen(trace, plan=ds_plan)
112+
exp_gen_type = type(self.merge_exp_gen).__name__
108113
else:
109114
# If there is a sota experiment in the sub-trace and not in merge time, we use default exp_gen
110115
exp = self.exp_gen.gen(trace, plan=ds_plan)
111-
116+
exp_gen_type = type(self.exp_gen).__name__
117+
end = datetime.now(timezone.utc)
118+
logger.log_object(
119+
{
120+
"exp_gen_type": exp_gen_type,
121+
"start_time": start,
122+
"end_time": end,
123+
},
124+
tag="exp_gen_time_info",
125+
)
112126
exp.set_local_selection(local_selection)
113127
exp.plan = ds_plan
114128
return exp

rdagent/utils/workflow/loop.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010

1111
import asyncio
1212
import concurrent.futures
13-
import datetime
1413
import pickle
1514
from collections import defaultdict
1615
from dataclasses import dataclass
16+
from datetime import datetime, timezone
1717
from pathlib import Path
1818
from typing import Any, Callable, Optional, Union, cast
1919

@@ -72,8 +72,8 @@ def __new__(mcs, clsname: str, bases: tuple[type, ...], attrs: dict[str, Any]) -
7272

7373
@dataclass
7474
class LoopTrace:
75-
start: datetime.datetime # the start time of the trace
76-
end: datetime.datetime # the end time of the trace
75+
start: datetime # the start time of the trace
76+
end: datetime # the end time of the trace
7777
step_idx: int
7878
# TODO: more information about the trace
7979

@@ -211,7 +211,7 @@ async def _run_step(self, li: int, force_subproc: bool = False) -> None:
211211
self.tracker.log_workflow_state()
212212

213213
with logger.tag(f"Loop_{li}.{name}"):
214-
start = datetime.datetime.now(datetime.timezone.utc)
214+
start = datetime.now(timezone.utc)
215215
func: Callable[..., Any] = cast(Callable[..., Any], getattr(self, name))
216216

217217
next_step_idx = si + 1
@@ -236,8 +236,15 @@ async def _run_step(self, li: int, force_subproc: bool = False) -> None:
236236
self.loop_prev_out[li][name] = result
237237

238238
# Record the trace
239-
end = datetime.datetime.now(datetime.timezone.utc)
239+
end = datetime.now(timezone.utc)
240240
self.loop_trace[li].append(LoopTrace(start, end, step_idx=si))
241+
logger.log_object(
242+
{
243+
"start_time": start,
244+
"end_time": end,
245+
},
246+
tag="time_info",
247+
)
241248
# Save snapshot after completing the step
242249
self.dump(self.session_folder / f"{li}" / f"{si}_{name}")
243250
except Exception as e:

0 commit comments

Comments
 (0)