Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 146 additions & 4 deletions toolchain/mfc/sched.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

High-level Suggestion

Refactor the notification state management by replacing the multiple boolean flags in WorkerThreadHolder with a single state field. This field would track the last notification milestone, simplifying both the data structure and the notification logic. [High-level, importance: 7]

Solution Walkthrough:

Before:

@dataclasses.dataclass
class WorkerThreadHolder:
    # ...
    notified_30s: bool = False
    notified_2m:  bool = False
    notified_10m: bool = False
    notified_30m: bool = False

def notify_long_running_threads(...):
    # ...
    if interactive:
        if elapsed >= threshold and not holder.notified_30s:
            # print notification
            holder.notified_30s = True
    else: # headless
        if not holder.notified_2m and elapsed >= 120:
            holder.notified_2m = True
        if not holder.notified_10m and elapsed >= 600:
            holder.notified_10m = True
        # ... and so on

After:

@dataclasses.dataclass
class WorkerThreadHolder:
    # ...
    # Tracks the index of the last notification sent. -1 means none.
    last_notification_idx: int = -1

def notify_long_running_threads(...):
    # ...
    if interactive:
        if elapsed >= threshold and holder.last_notification_idx == -1:
            # print notification
            holder.last_notification_idx = 0 # Mark as notified
    else: # headless
        for i, (threshold, msg) in enumerate(HEADLESS_THRESHOLDS):
            if elapsed >= threshold and i > holder.last_notification_idx:
                # print notification for this milestone
                holder.last_notification_idx = i

Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,21 @@

from .printer import cons

# Elapsed-time thresholds (in seconds) for long-running test notifications.

# Interactive mode: threshold looked up by case dimensionality (1, 2 or 3).
INTERACTIVE_THRESHOLDS = {
    1: 30.0,   # 1D: 30 seconds
    2: 60.0,   # 2D: 1 minute
    3: 120.0,  # 3D: 2 minutes
}

# Headless mode: (threshold in seconds, message markup) milestones, used
# regardless of dimensionality.
HEADLESS_THRESHOLDS = (
    (120, "[italic yellow]Still running[/italic yellow] (>2min)"),
    (600, "[italic yellow]Still running[/italic yellow] (>10min)"),
    (1800, "[bold red]Still running[/bold red] (>30min, may be hanging)"),
)

class WorkerThread(threading.Thread):
def __init__(self, *args, **kwargs):
self.exc = None
Expand All @@ -29,6 +44,13 @@ class WorkerThreadHolder:
ppn: int
load: float
devices: typing.Set[int]
task: typing.Optional['Task'] = None
start: float = 0.0
# Track which milestones we've already logged
notified_30s: bool = False # for interactive mode
notified_2m: bool = False
notified_10m: bool = False
notified_30m: bool = False


@dataclasses.dataclass
Expand All @@ -44,7 +66,99 @@ def sched(tasks: typing.List[Task], nThreads: int, devices: typing.Set[int] = No

sched.LOAD = { id: 0.0 for id in devices or [] }

def join_first_dead_thread(progress, complete_tracker) -> None:
def get_case_dimensionality(case) -> int:
    """Return 1, 2 or 3 for a test case, judged from its ``n`` and ``p`` parameters.

    A non-zero ``p`` means 3D; otherwise a non-zero ``n`` means 2D; anything
    else is 1D. A ``case`` without a ``params`` attribute is treated as 1D.
    """
    if not hasattr(case, 'params'):
        # Nothing to inspect: fall back to 1D.
        return 1

    settings = case.params
    p_value = settings.get('p', 0)
    n_value = settings.get('n', 0)

    if p_value != 0:
        return 3
    return 2 if n_value != 0 else 1

def get_threshold_for_case(case, interactive: bool) -> float:
    """Return the elapsed-time threshold after which ``case`` counts as long-running.

    Headless mode always uses the first ``HEADLESS_THRESHOLDS`` milestone.
    Interactive mode looks the threshold up by the case's dimensionality,
    falling back to the 1D value for an unknown dimensionality.
    """
    if not interactive:
        # Headless mode uses fixed thresholds; the first one is 2 minutes.
        return HEADLESS_THRESHOLDS[0][0]

    dimensionality = get_case_dimensionality(case)
    return INTERACTIVE_THRESHOLDS.get(dimensionality, INTERACTIVE_THRESHOLDS[1])

def notify_long_running_threads(progress, running_tracker, interactive: bool) -> None:
    """Report worker threads that have been running longer than expected.

    Interactive mode: every live thread whose elapsed time has reached its
    dimension-aware threshold is printed once and is listed in the
    ``running_tracker`` row of ``progress``.

    Headless mode: one line is printed per thread for each
    ``HEADLESS_THRESHOLDS`` milestone it has reached.

    Args:
        progress: Object whose ``update`` method refreshes the running row.
        running_tracker: Identifier of the running row, or ``None`` to skip
            updating it.
        interactive: ``True`` for interactive reporting, ``False`` for
            headless milestone reporting.
    """
    # Holder flag that records each HEADLESS_THRESHOLDS entry as printed.
    # Order must match the table; zip() stops at the shorter of the two.
    headless_flags = ("notified_2m", "notified_10m", "notified_30m")

    now = time.time()
    long_running_for_progress = []

    for holder in threads:
        if not holder.thread.is_alive():
            continue

        elapsed = now - holder.start
        case = holder.task.args[0] if holder.task and holder.task.args else None
        case_uuid = case.get_uuid() if hasattr(case, "get_uuid") else "unknown"
        case_trace = getattr(case, "trace", "")
        case_label = f"[bold magenta]{case_uuid}[/bold magenta] {case_trace}"

        # --- headless: milestone notifications taken from HEADLESS_THRESHOLDS ---
        if not interactive:
            # Trigger time and message both come from the table, so the two
            # cannot drift apart when a milestone is edited.
            for (limit, message), flag in zip(HEADLESS_THRESHOLDS, headless_flags):
                if elapsed >= limit and not getattr(holder, flag):
                    cons.print(f" {message} {case_label}")
                    setattr(holder, flag, True)
            continue

        # --- interactive: dimension-aware thresholds ---
        threshold = get_threshold_for_case(case, interactive=True)
        if elapsed < threshold:
            continue

        long_running_for_progress.append((case_uuid, case_trace))

        # Print an explicit line only once, when the threshold is first crossed.
        # `notified_30s` is the single interactive flag for every threshold
        # (30 s, 1 min and 2 min), despite its name.
        if holder.notified_30s:
            continue

        dim = get_case_dimensionality(case)
        time_label = f"{int(threshold)}s" if threshold < 60 else f"{int(threshold/60)}min"
        cons.print(
            f" [italic yellow]Still running[/italic yellow] ({dim}D, >{time_label}) "
            f"{case_label}"
        )
        holder.notified_30s = True

    # Update the interactive "Running" row.
    if interactive and running_tracker is not None:
        if long_running_for_progress:
            # Show at most five identifiers and summarise the remainder.
            summary = ", ".join(case_id for case_id, _ in long_running_for_progress[:5])
            if len(long_running_for_progress) > 5:
                summary += f", +{len(long_running_for_progress) - 5} more"
            progress.update(running_tracker, description=f"Running (long): {summary}")
        else:
            progress.update(running_tracker, description="Running (long): -")

def join_first_dead_thread(progress, complete_tracker, interactive: bool) -> None:
nonlocal threads, nAvailable

for threadID, threadHolder in enumerate(threads):
Expand Down Expand Up @@ -82,6 +196,17 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
else:
raise threadHolder.thread.exc

# Print completion message for long-running tests in interactive mode
if interactive and threadHolder.notified_30s:
elapsed = time.time() - threadHolder.start
case = threadHolder.task.args[0] if threadHolder.task and threadHolder.task.args else None
case_uuid = case.get_uuid() if hasattr(case, "get_uuid") else "unknown"
case_trace = getattr(case, "trace", "")
cons.print(
f" [italic green]Completed[/italic green] (after {elapsed:.1f}s) "
f"[bold magenta]{case_uuid}[/bold magenta] {case_trace}"
)

nAvailable += threadHolder.ppn
for device in threadHolder.devices or set():
sched.LOAD[device] -= threadHolder.load / threadHolder.ppn
Expand All @@ -93,8 +218,10 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
break

with rich.progress.Progress(console=cons.raw, transient=True) as progress:
interactive = cons.raw.is_terminal
queue_tracker = progress.add_task("Queued ", total=len(tasks))
complete_tracker = progress.add_task("Completed", total=len(tasks))
running_tracker = progress.add_task("Running ", total=None) if interactive else None

# Queue Tests
for task in tasks:
Expand All @@ -106,7 +233,10 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
break

# Keep track of threads that are done
join_first_dead_thread(progress, complete_tracker)
join_first_dead_thread(progress, complete_tracker, interactive)

# Notify about long-running threads
notify_long_running_threads(progress, running_tracker, interactive)

# Do not overwhelm this core with this loop
time.sleep(0.05)
Expand All @@ -128,12 +258,24 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
thread = WorkerThread(target=task.func, args=tuple(task.args) + (use_devices,))
thread.start()

threads.append(WorkerThreadHolder(thread, task.ppn, task.load, use_devices))
threads.append(
WorkerThreadHolder(
thread=thread,
ppn=task.ppn,
load=task.load,
devices=use_devices,
task=task,
start=time.time(),
)
)

# Wait for the last tests to complete while the progress context is still open
while len(threads) != 0:
# Keep track of threads that are done
join_first_dead_thread(progress, complete_tracker)
join_first_dead_thread(progress, complete_tracker, interactive)

# Notify about long-running threads
notify_long_running_threads(progress, running_tracker, interactive)

# Do not overwhelm this core with this loop
time.sleep(0.05)
Expand Down
Loading