Skip to content

Commit 6daae96

Browse files
committed
fix
1 parent 56d8a24 commit 6daae96

File tree

1 file changed

+146
-4
lines changed

1 file changed

+146
-4
lines changed

toolchain/mfc/sched.py

Lines changed: 146 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,21 @@
44

55
from .printer import cons
66

7+
# Thresholds for long-running test notifications
8+
# Interactive mode: dimension-aware thresholds
9+
INTERACTIVE_THRESHOLDS = {
10+
1: 30.0, # 1D: 30 seconds
11+
2: 60.0, # 2D: 1 minute
12+
3: 120.0, # 3D: 2 minutes
13+
}
14+
15+
# Headless mode: fixed time-based thresholds (regardless of dimensionality)
16+
HEADLESS_THRESHOLDS = (
17+
(2 * 60, "[italic yellow]Still running[/italic yellow] (>2min)"),
18+
(10 * 60, "[italic yellow]Still running[/italic yellow] (>10min)"),
19+
(30 * 60, "[bold red]Still running[/bold red] (>30min, may be hanging)"),
20+
)
21+
722
class WorkerThread(threading.Thread):
823
def __init__(self, *args, **kwargs):
924
self.exc = None
@@ -29,6 +44,13 @@ class WorkerThreadHolder:
2944
ppn: int
3045
load: float
3146
devices: typing.Set[int]
47+
task: typing.Optional['Task'] = None
48+
start: float = 0.0
49+
# Track which milestones we've already logged
50+
notified_30s: bool = False # for interactive mode
51+
notified_2m: bool = False
52+
notified_10m: bool = False
53+
notified_30m: bool = False
3254

3355

3456
@dataclasses.dataclass
@@ -44,7 +66,99 @@ def sched(tasks: typing.List[Task], nThreads: int, devices: typing.Set[int] = No
4466

4567
sched.LOAD = { id: 0.0 for id in devices or [] }
4668

47-
def join_first_dead_thread(progress, complete_tracker) -> None:
69+
def get_case_dimensionality(case) -> int:
70+
"""Determine if a test case is 1D, 2D, or 3D based on m, n, p parameters."""
71+
if not hasattr(case, 'params'):
72+
return 1 # Default to 1D if we can't determine
73+
74+
params = case.params
75+
p = params.get('p', 0)
76+
n = params.get('n', 0)
77+
78+
if p != 0:
79+
return 3 # 3D
80+
elif n != 0:
81+
return 2 # 2D
82+
else:
83+
return 1 # 1D
84+
85+
def get_threshold_for_case(case, interactive: bool) -> float:
86+
"""Get the appropriate threshold for a test case."""
87+
if interactive:
88+
dim = get_case_dimensionality(case)
89+
return INTERACTIVE_THRESHOLDS.get(dim, INTERACTIVE_THRESHOLDS[1])
90+
else:
91+
# Headless mode uses fixed thresholds
92+
return HEADLESS_THRESHOLDS[0][0] # 2 minutes
93+
94+
def notify_long_running_threads(progress, running_tracker, interactive: bool) -> None:
95+
now = time.time()
96+
long_running_for_progress = []
97+
98+
for holder in threads:
99+
if not holder.thread.is_alive():
100+
continue
101+
102+
elapsed = now - holder.start
103+
case = holder.task.args[0] if holder.task and holder.task.args else None
104+
case_uuid = case.get_uuid() if hasattr(case, "get_uuid") else "unknown"
105+
case_trace = getattr(case, "trace", "")
106+
107+
# --- interactive: dimension-aware thresholds ---
108+
if interactive:
109+
threshold = get_threshold_for_case(case, interactive=True)
110+
111+
if elapsed >= threshold:
112+
long_running_for_progress.append((case_uuid, case_trace))
113+
114+
# Print explicit line once when crossing threshold
115+
if not holder.notified_30s:
116+
dim = get_case_dimensionality(case)
117+
dim_label = f"{dim}D"
118+
time_label = f"{int(threshold)}s" if threshold < 60 else f"{int(threshold/60)}min"
119+
cons.print(
120+
f" [italic yellow]Still running[/italic yellow] ({dim_label}, >{time_label}) "
121+
f"[bold magenta]{case_uuid}[/bold magenta] {case_trace}"
122+
)
123+
holder.notified_30s = True
124+
125+
# --- headless: milestone notifications at 2, 10, 30 minutes ---
126+
else:
127+
# 2 minutes
128+
if (not holder.notified_2m) and elapsed >= 2 * 60:
129+
cons.print(
130+
f" {HEADLESS_THRESHOLDS[0][1]} "
131+
f"[bold magenta]{case_uuid}[/bold magenta] {case_trace}"
132+
)
133+
holder.notified_2m = True
134+
135+
# 10 minutes
136+
if (not holder.notified_10m) and elapsed >= 10 * 60:
137+
cons.print(
138+
f" {HEADLESS_THRESHOLDS[1][1]} "
139+
f"[bold magenta]{case_uuid}[/bold magenta] {case_trace}"
140+
)
141+
holder.notified_10m = True
142+
143+
# 30 minutes
144+
if (not holder.notified_30m) and elapsed >= 30 * 60:
145+
cons.print(
146+
f" {HEADLESS_THRESHOLDS[2][1]} "
147+
f"[bold magenta]{case_uuid}[/bold magenta] {case_trace}"
148+
)
149+
holder.notified_30m = True
150+
151+
# update the interactive "Running" row
152+
if interactive and running_tracker is not None:
153+
if long_running_for_progress:
154+
summary = ", ".join(uuid for uuid, _ in long_running_for_progress[:5])
155+
if len(long_running_for_progress) > 5:
156+
summary += f", +{len(long_running_for_progress) - 5} more"
157+
progress.update(running_tracker, description=f"Running (long): {summary}")
158+
else:
159+
progress.update(running_tracker, description="Running (long): -")
160+
161+
def join_first_dead_thread(progress, complete_tracker, interactive: bool) -> None:
48162
nonlocal threads, nAvailable
49163

50164
for threadID, threadHolder in enumerate(threads):
@@ -82,6 +196,17 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
82196
else:
83197
raise threadHolder.thread.exc
84198

199+
# Print completion message for long-running tests in interactive mode
200+
if interactive and threadHolder.notified_30s:
201+
elapsed = time.time() - threadHolder.start
202+
case = threadHolder.task.args[0] if threadHolder.task and threadHolder.task.args else None
203+
case_uuid = case.get_uuid() if hasattr(case, "get_uuid") else "unknown"
204+
case_trace = getattr(case, "trace", "")
205+
cons.print(
206+
f" [italic green]Completed[/italic green] (after {elapsed:.1f}s) "
207+
f"[bold magenta]{case_uuid}[/bold magenta] {case_trace}"
208+
)
209+
85210
nAvailable += threadHolder.ppn
86211
for device in threadHolder.devices or set():
87212
sched.LOAD[device] -= threadHolder.load / threadHolder.ppn
@@ -93,8 +218,10 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
93218
break
94219

95220
with rich.progress.Progress(console=cons.raw, transient=True) as progress:
221+
interactive = cons.raw.is_terminal
96222
queue_tracker = progress.add_task("Queued ", total=len(tasks))
97223
complete_tracker = progress.add_task("Completed", total=len(tasks))
224+
running_tracker = progress.add_task("Running ", total=None) if interactive else None
98225

99226
# Queue Tests
100227
for task in tasks:
@@ -106,7 +233,10 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
106233
break
107234

108235
# Keep track of threads that are done
109-
join_first_dead_thread(progress, complete_tracker)
236+
join_first_dead_thread(progress, complete_tracker, interactive)
237+
238+
# Notify about long-running threads
239+
notify_long_running_threads(progress, running_tracker, interactive)
110240

111241
# Do not overwhelm this core with this loop
112242
time.sleep(0.05)
@@ -128,12 +258,24 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
128258
thread = WorkerThread(target=task.func, args=tuple(task.args) + (use_devices,))
129259
thread.start()
130260

131-
threads.append(WorkerThreadHolder(thread, task.ppn, task.load, use_devices))
261+
threads.append(
262+
WorkerThreadHolder(
263+
thread=thread,
264+
ppn=task.ppn,
265+
load=task.load,
266+
devices=use_devices,
267+
task=task,
268+
start=time.time(),
269+
)
270+
)
132271

133272
# Wait for the last tests to complete (MOVED INSIDE CONTEXT)
134273
while len(threads) != 0:
135274
# Keep track of threads that are done
136-
join_first_dead_thread(progress, complete_tracker)
275+
join_first_dead_thread(progress, complete_tracker, interactive)
276+
277+
# Notify about long-running threads
278+
notify_long_running_threads(progress, running_tracker, interactive)
137279

138280
# Do not overwhelm this core with this loop
139281
time.sleep(0.05)

0 commit comments

Comments
 (0)