Skip to content

Commit abe7de4

Browse files
committed
Merge upstream/master: resolve conflict in sched.py
- Keep typing.Optional[typing.Set[int]] for the devices parameter (from our branch)
- Integrate dimension-aware long-running test notifications (from upstream PR #1067)
- Combine the WorkerThreadHolder fields and the sched function signature improvements from both sides
2 parents 37565ae + e2ce07d commit abe7de4

File tree

1 file changed

+163
-6
lines changed

1 file changed

+163
-6
lines changed

toolchain/mfc/sched.py

Lines changed: 163 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,21 @@
44

55
from .printer import cons
66

7+
# Thresholds for long-running test notifications.

# Interactive mode: seconds a test may run before it is reported as
# "still running", keyed by dimensionality (1 -> 1D, 2 -> 2D, 3 -> 3D).
INTERACTIVE_THRESHOLDS = {1: 30.0, 2: 60.0, 3: 120.0}

# Headless mode: fixed (elapsed seconds, message) milestones that apply
# regardless of dimensionality, listed in increasing order of elapsed time.
HEADLESS_THRESHOLDS = (
    (120, "[italic yellow]Still running[/italic yellow] (>2min)"),    # 2 minutes
    (600, "[italic yellow]Still running[/italic yellow] (>10min)"),   # 10 minutes
    (1800, "[bold red]Still running[/bold red] (>30min, may be hanging)"),  # 30 minutes
)
21+
722
class WorkerThread(threading.Thread):
823
def __init__(self, *args, **kwargs):
924
self.exc = None
@@ -29,6 +44,13 @@ class WorkerThreadHolder: # pylint: disable=too-many-instance-attributes
2944
ppn: int
3045
load: float
3146
devices: typing.Optional[typing.Set[int]]
47+
task: typing.Optional['Task'] = None
48+
start: float = 0.0
49+
# Track which milestones we've already logged
50+
notified_interactive: bool = False # First notification in interactive mode (time varies by dimensionality)
51+
notified_2m: bool = False # Headless mode: 2 minute milestone
52+
notified_10m: bool = False # Headless mode: 10 minute milestone
53+
notified_30m: bool = False # Headless mode: 30 minute milestone
3254

3355

3456
@dataclasses.dataclass
@@ -44,7 +66,114 @@ def sched(tasks: typing.List[Task], nThreads: int, devices: typing.Optional[typi
4466

4567
sched.LOAD = { id: 0.0 for id in devices or [] }
4668

47-
def join_first_dead_thread(progress, complete_tracker) -> None:
69+
def get_case_dimensionality(case: typing.Any) -> int:
    """
    Determine if a test case is 1D, 2D, or 3D based on grid parameters.

    Grid parameters (m, n, p) represent the number of cells in x, y, z directions.

    Args:
        case: Any object; only its optional ``params`` mapping is inspected.

    Returns:
        3 if ``p`` is set and non-zero, 2 if ``n`` is set and non-zero,
        otherwise 1. Defaults to 1 (1D) when ``case`` has no ``params``
        attribute or when ``params`` is ``None``/empty.
    """
    # getattr with a default also covers ``case is None`` and a ``params``
    # attribute that exists but holds None, neither of which supports .get().
    params = getattr(case, 'params', None)
    if not params:
        return 1  # Default to 1D if we can't determine

    # Truthiness rather than ``!= 0`` so that an explicitly unset (None)
    # entry is treated like a missing one instead of as a populated axis.
    if params.get('p', 0):
        return 3  # 3D
    if params.get('n', 0):
        return 2  # 2D
    return 1  # 1D
88+
89+
def get_threshold_for_case(case: typing.Any) -> float:
    """
    Return the interactive-mode notification threshold, in seconds, for a case.

    The case's dimensionality (from get_case_dimensionality) is looked up in
    INTERACTIVE_THRESHOLDS; a dimensionality with no entry falls back to the
    1D threshold.
    """
    dimensionality = get_case_dimensionality(case)
    fallback = INTERACTIVE_THRESHOLDS[1]

    if dimensionality in INTERACTIVE_THRESHOLDS:
        return INTERACTIVE_THRESHOLDS[dimensionality]
    return fallback
97+
98+
def notify_long_running_threads(  # pylint: disable=too-many-branches
    progress: rich.progress.Progress,
    running_tracker: typing.Optional[rich.progress.TaskID],
    interactive: bool
) -> None:
    """
    Monitor and notify about long-running tests.

    Iterates over ``threads`` (read from the enclosing scope) and, for every
    holder whose thread is still alive, compares the time elapsed since
    ``holder.start`` against the applicable threshold(s).

    In interactive mode: prints once when a test crosses its dimension-aware
    threshold (see get_threshold_for_case) and updates the description of the
    ``running_tracker`` row of the live progress bar.

    In headless mode: prints one line per milestone in HEADLESS_THRESHOLDS.
    Both the elapsed-time limit and the message come from that constant, so
    the two cannot drift apart.

    Args:
        progress: Progress display whose ``running_tracker`` row is updated.
        running_tracker: Task id of the "Running" row, or None when there is
            no such row; nothing is updated in that case.
        interactive: Selects interactive (True) or headless (False) behaviour.

    Side effects:
        Sets the ``notified_*`` flags on each holder so that every message is
        printed at most once per thread.
    """
    now = time.time()
    long_running_for_progress = []

    # Pair each headless milestone with the holder flag recording that it has
    # been reported. zip() stops at the shorter sequence, so a milestone added
    # to HEADLESS_THRESHOLDS also needs a flag name here (and on the holder).
    headless_milestones = tuple(zip(
        HEADLESS_THRESHOLDS,
        ("notified_2m", "notified_10m", "notified_30m"),
    ))

    for holder in threads:
        if not holder.thread.is_alive():
            continue

        elapsed = now - holder.start
        case = holder.task.args[0] if holder.task and holder.task.args else None
        case_uuid = case.get_uuid() if hasattr(case, "get_uuid") else "unknown"
        case_trace = getattr(case, "trace", "")

        # --- interactive: dimension-aware thresholds ---
        if interactive:
            threshold = get_threshold_for_case(case)
            if elapsed < threshold:
                continue

            long_running_for_progress.append((case_uuid, case_trace))

            # Print explicit line once when crossing threshold
            if not holder.notified_interactive:
                dim = get_case_dimensionality(case)
                time_label = f"{int(threshold)}s" if threshold < 60 else f"{threshold/60:.0f}min"
                cons.print(
                    f" [italic yellow]Still running[/italic yellow] ({dim}D, >{time_label}) "
                    f"[bold magenta]{case_uuid}[/bold magenta] {case_trace}"
                )
                holder.notified_interactive = True
            continue

        # --- headless: one notification per milestone in HEADLESS_THRESHOLDS ---
        for (limit, message), flag in headless_milestones:
            if (not getattr(holder, flag)) and elapsed >= limit:
                cons.print(
                    f" {message} "
                    f"[bold magenta]{case_uuid}[/bold magenta] {case_trace}"
                )
                setattr(holder, flag, True)

    # update the interactive "Running" row
    if interactive and running_tracker is not None:
        if long_running_for_progress:
            # Show at most five UUIDs and summarise the remainder as a count
            summary = ", ".join(uuid for uuid, _ in long_running_for_progress[:5])
            if len(long_running_for_progress) > 5:
                summary += f", +{len(long_running_for_progress) - 5} more"
            progress.update(running_tracker, description=f"Running (long): {summary}")
        else:
            progress.update(running_tracker, description="Running (long): none")
175+
176+
def join_first_dead_thread(progress, complete_tracker, interactive: bool) -> None:
48177
nonlocal threads, nAvailable
49178

50179
for threadID, threadHolder in enumerate(threads):
@@ -75,6 +204,17 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
75204
raise RuntimeError(error_msg) from threadHolder.thread.exc
76205
raise threadHolder.thread.exc
77206

207+
# Print completion message for long-running tests in interactive mode
208+
if interactive and threadHolder.notified_interactive:
209+
elapsed = time.time() - threadHolder.start
210+
case = threadHolder.task.args[0] if threadHolder.task and threadHolder.task.args else None
211+
case_uuid = case.get_uuid() if hasattr(case, "get_uuid") else "unknown"
212+
case_trace = getattr(case, "trace", "")
213+
cons.print(
214+
f" [italic green]Completed[/italic green] (after {elapsed:.1f}s) "
215+
f"[bold magenta]{case_uuid}[/bold magenta] {case_trace}"
216+
)
217+
78218
nAvailable += threadHolder.ppn
79219
for device in threadHolder.devices or set():
80220
sched.LOAD[device] -= threadHolder.load / threadHolder.ppn
@@ -86,8 +226,10 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
86226
break
87227

88228
with rich.progress.Progress(console=cons.raw, transient=True) as progress:
89-
queue_tracker = progress.add_task("Queued ", total=len(tasks))
90-
complete_tracker = progress.add_task("Completed", total=len(tasks))
229+
interactive = cons.raw.is_terminal
230+
queue_tracker = progress.add_task("Queued ", total=len(tasks))
231+
complete_tracker = progress.add_task("Completed ", total=len(tasks))
232+
running_tracker = progress.add_task("Running ", total=None) if interactive else None
91233

92234
# Queue Tests
93235
for task in tasks:
@@ -99,7 +241,10 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
99241
break
100242

101243
# Keep track of threads that are done
102-
join_first_dead_thread(progress, complete_tracker)
244+
join_first_dead_thread(progress, complete_tracker, interactive)
245+
246+
# Notify about long-running threads
247+
notify_long_running_threads(progress, running_tracker, interactive)
103248

104249
# Do not overwhelm this core with this loop
105250
time.sleep(0.05)
@@ -121,12 +266,24 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
121266
thread = WorkerThread(target=task.func, args=tuple(task.args) + (use_devices,))
122267
thread.start()
123268

124-
threads.append(WorkerThreadHolder(thread, task.ppn, task.load, use_devices))
269+
threads.append(
270+
WorkerThreadHolder(
271+
thread=thread,
272+
ppn=task.ppn,
273+
load=task.load,
274+
devices=use_devices,
275+
task=task,
276+
start=time.time(),
277+
)
278+
)
125279

126280
# Wait for the last tests to complete (MOVED INSIDE CONTEXT)
127281
while len(threads) != 0:
128282
# Keep track of threads that are done
129-
join_first_dead_thread(progress, complete_tracker)
283+
join_first_dead_thread(progress, complete_tracker, interactive)
284+
285+
# Notify about long-running threads
286+
notify_long_running_threads(progress, running_tracker, interactive)
130287

131288
# Do not overwhelm this core with this loop
132289
time.sleep(0.05)

0 commit comments

Comments
 (0)