Skip to content

Commit e2ce07d

Browse files
authored
Add dimension-aware long-running test notifications (#1067)
1 parent 987888c commit e2ce07d

File tree

1 file changed

+165
-8
lines changed

1 file changed

+165
-8
lines changed

toolchain/mfc/sched.py

Lines changed: 165 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,21 @@
44

55
from .printer import cons
66

7+
# Thresholds for long-running test notifications.

# Interactive mode: seconds a test may run before it is reported, keyed by
# the dimensionality of the case (1D, 2D or 3D).
INTERACTIVE_THRESHOLDS = {
    1: 30.0,   # 1D: 30 seconds
    2: 60.0,   # 2D: 1 minute
    3: 120.0,  # 3D: 2 minutes
}

# Headless mode: fixed (seconds, message) milestones that apply to every
# test regardless of its dimensionality, in ascending order.
HEADLESS_THRESHOLDS = (
    (120, "[italic yellow]Still running[/italic yellow] (>2min)"),
    (600, "[italic yellow]Still running[/italic yellow] (>10min)"),
    (1800, "[bold red]Still running[/bold red] (>30min, may be hanging)"),
)
21+
722
class WorkerThread(threading.Thread):
823
def __init__(self, *args, **kwargs):
924
self.exc = None
@@ -24,11 +39,18 @@ def run(self):
2439

2540

2641
@dataclasses.dataclass
class WorkerThreadHolder:  # pylint: disable=too-many-instance-attributes
    """Bookkeeping record the scheduler keeps for one running worker thread."""

    # Scheduling state
    thread: threading.Thread
    ppn: int
    load: float
    devices: typing.Set[int]

    # The task being executed and the time.time() value at which it started
    task: typing.Optional['Task'] = None
    start: float = 0.0

    # Notification milestones that have already been logged for this thread.
    # Interactive mode has a single milestone whose time depends on the
    # dimensionality of the case.
    notified_interactive: bool = False
    # Headless mode has fixed milestones at 2, 10 and 30 minutes.
    notified_2m: bool = False
    notified_10m: bool = False
    notified_30m: bool = False
3254

3355

3456
@dataclasses.dataclass
@@ -38,13 +60,120 @@ class Task:
3860
args: typing.List[typing.Any]
3961
load: float
4062

41-
def sched(tasks: typing.List[Task], nThreads: int, devices: typing.Set[int] = None) -> None:
63+
def sched(tasks: typing.List[Task], nThreads: int, devices: typing.Set[int] = None) -> None: # pylint: disable=too-many-locals,too-many-statements
4264
nAvailable: int = nThreads
4365
threads: typing.List[WorkerThreadHolder] = []
4466

4567
sched.LOAD = { id: 0.0 for id in devices or [] }
4668

47-
def join_first_dead_thread(progress, complete_tracker) -> None:
69+
def get_case_dimensionality(case: typing.Any) -> int:
    """
    Classify a test case as 1D, 2D or 3D from its grid parameters.

    The grid parameters (m, n, p) give the number of cells in the x, y and z
    directions. A non-zero `p` means 3D; otherwise a non-zero `n` means 2D;
    otherwise the case is 1D. A case without a `params` attribute is treated
    as 1D.
    """
    try:
        grid = case.params
    except AttributeError:
        # Nothing to inspect: fall back to the cheapest classification
        return 1

    cells_z, cells_y = grid.get('p', 0), grid.get('n', 0)

    # The highest populated direction decides the dimensionality
    return 3 if cells_z != 0 else (2 if cells_y != 0 else 1)
88+
89+
def get_threshold_for_case(case: typing.Any) -> float:
    """
    Return the interactive-mode notification threshold, in seconds, for a case.

    The threshold is looked up in INTERACTIVE_THRESHOLDS by the case's
    dimensionality; a dimensionality missing from the table falls back to
    the 1D entry.
    """
    dimension = get_case_dimensionality(case)
    fallback = INTERACTIVE_THRESHOLDS[1]
    return INTERACTIVE_THRESHOLDS.get(dimension, fallback)
97+
98+
def notify_long_running_threads(
    progress: rich.progress.Progress,
    running_tracker: typing.Optional[rich.progress.TaskID],
    interactive: bool
) -> None:
    """
    Monitor the running worker threads and report the long-running ones.

    In interactive mode a line is printed once per test when it crosses its
    dimension-aware threshold, and the "Running" row of the live progress
    display lists the tests that are currently past their threshold. In
    headless mode a line is printed once per test for each milestone in
    HEADLESS_THRESHOLDS that its elapsed time has reached.

    Args:
        progress: Progress display that owns `running_tracker`.
        running_tracker: Task ID of the "Running" row, or None when there is
            no such row (nothing is updated in that case).
        interactive: Selects interactive (True) or headless (False) reporting.
    """
    now = time.time()
    long_running_for_progress = []

    # Holder flags paired positionally with the HEADLESS_THRESHOLDS entries.
    # The limits and messages both come from that table so they cannot drift
    # apart; entries beyond these three flags are ignored by zip().
    headless_flags = ("notified_2m", "notified_10m", "notified_30m")

    for holder in threads:
        if not holder.thread.is_alive():
            continue

        elapsed = now - holder.start
        case = holder.task.args[0] if holder.task and holder.task.args else None
        case_uuid = case.get_uuid() if hasattr(case, "get_uuid") else "unknown"
        case_trace = getattr(case, "trace", "")
        case_label = f"[bold magenta]{case_uuid}[/bold magenta] {case_trace}"

        # --- headless: one line per milestone reached, in ascending order ---
        if not interactive:
            for (limit, message), flag in zip(HEADLESS_THRESHOLDS, headless_flags):
                if elapsed >= limit and not getattr(holder, flag):
                    cons.print(f" {message} {case_label}")
                    setattr(holder, flag, True)
            continue

        # --- interactive: dimension-aware threshold ---
        threshold = get_threshold_for_case(case)
        if elapsed < threshold:
            continue

        long_running_for_progress.append((case_uuid, case_trace))

        # Print an explicit line only the first time the threshold is crossed
        if not holder.notified_interactive:
            dim = get_case_dimensionality(case)
            time_label = f"{int(threshold)}s" if threshold < 60 else f"{threshold/60:.0f}min"
            cons.print(
                f" [italic yellow]Still running[/italic yellow] ({dim}D, >{time_label}) "
                f"{case_label}"
            )
            holder.notified_interactive = True

    # Update the interactive "Running" row
    if interactive and running_tracker is not None:
        if long_running_for_progress:
            # Show at most five UUIDs and summarize the rest as a count
            summary = ", ".join(uuid for uuid, _ in long_running_for_progress[:5])
            if len(long_running_for_progress) > 5:
                summary += f", +{len(long_running_for_progress) - 5} more"
            progress.update(running_tracker, description=f"Running (long): {summary}")
        else:
            progress.update(running_tracker, description="Running (long): none")
175+
176+
def join_first_dead_thread(progress, complete_tracker, interactive: bool) -> None:
48177
nonlocal threads, nAvailable
49178

50179
for threadID, threadHolder in enumerate(threads):
@@ -82,6 +211,17 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
82211
else:
83212
raise threadHolder.thread.exc
84213

214+
# Print completion message for long-running tests in interactive mode
215+
if interactive and threadHolder.notified_interactive:
216+
elapsed = time.time() - threadHolder.start
217+
case = threadHolder.task.args[0] if threadHolder.task and threadHolder.task.args else None
218+
case_uuid = case.get_uuid() if hasattr(case, "get_uuid") else "unknown"
219+
case_trace = getattr(case, "trace", "")
220+
cons.print(
221+
f" [italic green]Completed[/italic green] (after {elapsed:.1f}s) "
222+
f"[bold magenta]{case_uuid}[/bold magenta] {case_trace}"
223+
)
224+
85225
nAvailable += threadHolder.ppn
86226
for device in threadHolder.devices or set():
87227
sched.LOAD[device] -= threadHolder.load / threadHolder.ppn
@@ -93,8 +233,10 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
93233
break
94234

95235
with rich.progress.Progress(console=cons.raw, transient=True) as progress:
96-
queue_tracker = progress.add_task("Queued ", total=len(tasks))
97-
complete_tracker = progress.add_task("Completed", total=len(tasks))
236+
interactive = cons.raw.is_terminal
237+
queue_tracker = progress.add_task("Queued ", total=len(tasks))
238+
complete_tracker = progress.add_task("Completed ", total=len(tasks))
239+
running_tracker = progress.add_task("Running ", total=None) if interactive else None
98240

99241
# Queue Tests
100242
for task in tasks:
@@ -106,7 +248,10 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
106248
break
107249

108250
# Keep track of threads that are done
109-
join_first_dead_thread(progress, complete_tracker)
251+
join_first_dead_thread(progress, complete_tracker, interactive)
252+
253+
# Notify about long-running threads
254+
notify_long_running_threads(progress, running_tracker, interactive)
110255

111256
# Do not overwhelm this core with this loop
112257
time.sleep(0.05)
@@ -128,12 +273,24 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
128273
thread = WorkerThread(target=task.func, args=tuple(task.args) + (use_devices,))
129274
thread.start()
130275

131-
threads.append(WorkerThreadHolder(thread, task.ppn, task.load, use_devices))
276+
threads.append(
277+
WorkerThreadHolder(
278+
thread=thread,
279+
ppn=task.ppn,
280+
load=task.load,
281+
devices=use_devices,
282+
task=task,
283+
start=time.time(),
284+
)
285+
)
132286

133287
# Wait for the last tests to complete (MOVED INSIDE CONTEXT)
134288
while len(threads) != 0:
135289
# Keep track of threads that are done
136-
join_first_dead_thread(progress, complete_tracker)
290+
join_first_dead_thread(progress, complete_tracker, interactive)
291+
292+
# Notify about long-running threads
293+
notify_long_running_threads(progress, running_tracker, interactive)
137294

138295
# Do not overwhelm this core with this loop
139296
time.sleep(0.05)

0 commit comments

Comments
 (0)