Skip to content

Commit 9f33b05

Browse files
committed
Improve parity runner topic controls and always upload artifacts
1 parent 11d0fb3 commit 9f33b05

File tree

2 files changed

+178
-24
lines changed

2 files changed

+178
-24
lines changed

.github/workflows/matlab-parity-gate.yml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,20 @@ jobs:
6262
python3 python/tools/verify_python_vs_matlab_similarity.py --enforce-gate
6363
6464
- name: Freeze similarity baseline
65+
if: always()
6566
run: |
66-
python3 python/tools/freeze_similarity_baseline.py
67+
if [ -f python/reports/python_vs_matlab_similarity_report.json ]; then
68+
python3 python/tools/freeze_similarity_baseline.py
69+
else
70+
echo "Similarity report missing; skipping baseline freeze."
71+
fi
6772
6873
- name: Upload parity reports
74+
if: always()
6975
uses: actions/upload-artifact@v4
7076
with:
7177
name: nstat-matlab-parity-reports
78+
if-no-files-found: warn
7279
path: |
7380
python/reports/python_vs_matlab_similarity_report.json
7481
python/reports/python_vs_matlab_similarity_baseline.json

python/tools/verify_python_vs_matlab_similarity.py

Lines changed: 170 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from typing import Any
1818

1919
REPO_ROOT = Path(__file__).resolve().parents[2]
20-
REPORT_DIR = REPO_ROOT / "python" / "reports"
2120
MATLAB_BIN = Path("/Applications/MATLAB_R2025b.app/bin/matlab")
2221
MATLAB_EXTRA_ARGS = [arg for arg in os.environ.get("NSTAT_MATLAB_EXTRA_ARGS", "").split() if arg]
2322
FORCE_M_HELP_SCRIPTS = os.environ.get("NSTAT_FORCE_M_HELP_SCRIPTS", "").strip().lower() in {"1", "true", "yes", "on"}
@@ -65,6 +64,21 @@
6564
"StimulusDecode2D",
6665
"nSTATPaperExamples",
6766
}
67+
DEFAULT_HELP_TOPIC_TIMEOUT_S = 120
68+
DEFAULT_TOPIC_TIMEOUT_OVERRIDES: dict[str, int] = {
69+
"SignalObjExamples": 180,
70+
"CovariateExamples": 180,
71+
"CovCollExamples": 180,
72+
"nSpikeTrainExamples": 180,
73+
"nstCollExamples": 180,
74+
"EventsExamples": 180,
75+
"HistoryExamples": 180,
76+
"TrialExamples": 180,
77+
"AnalysisExamples": 180,
78+
"DecodingExampleWithHist": 360,
79+
"StimulusDecode2D": 180,
80+
"nSTATPaperExamples": 240,
81+
}
6882

6983

7084
def _matlab_batch_command(batch_cmd: str) -> list[str]:
@@ -271,6 +285,49 @@ def _example_topics() -> list[tuple[str, str]]:
271285
return out
272286

273287

288+
def _parse_topics_arg(topics_arg: list[str] | None) -> set[str] | None:
289+
if not topics_arg:
290+
return None
291+
topics: set[str] = set()
292+
for raw in topics_arg:
293+
for part in raw.split(","):
294+
stem = part.strip()
295+
if stem:
296+
topics.add(stem)
297+
return topics or None
298+
299+
300+
def _parse_topic_timeout_overrides(specs: list[str]) -> dict[str, int]:
301+
out: dict[str, int] = {}
302+
for spec in specs:
303+
key, sep, value = spec.partition("=")
304+
topic = key.strip()
305+
raw_seconds = value.strip()
306+
if sep != "=" or not topic or not raw_seconds:
307+
raise ValueError(f"invalid --topic-timeout '{spec}'; expected TOPIC=SECONDS")
308+
try:
309+
seconds = int(raw_seconds)
310+
except ValueError as exc:
311+
raise ValueError(f"invalid timeout value in '{spec}': {raw_seconds}") from exc
312+
if seconds <= 0:
313+
raise ValueError(f"timeout must be positive in '{spec}'")
314+
out[topic] = seconds
315+
return out
316+
317+
318+
def _resolve_topics(requested_topics: set[str] | None) -> list[tuple[str, str]]:
319+
topics = _example_topics()
320+
if requested_topics is None:
321+
return topics
322+
323+
available = {Path(target).stem for _, target in topics}
324+
missing = sorted(requested_topics - available)
325+
if missing:
326+
raise ValueError(f"unknown topic(s): {missing}")
327+
328+
return [(title, target) for title, target in topics if Path(target).stem in requested_topics]
329+
330+
274331
def _run_python_topic(stem: str) -> dict[str, Any]:
275332
try:
276333
mod = importlib.import_module(f"examples.help_topics.{stem}")
@@ -429,9 +486,12 @@ def _compare_topic_scalars(py_scalars: dict[str, float], ml_scalars: dict[str, f
429486
}
430487

431488

432-
def _help_similarity() -> dict[str, Any]:
489+
def _help_similarity(
490+
topics: list[tuple[str, str]],
491+
default_timeout_s: int = DEFAULT_HELP_TOPIC_TIMEOUT_S,
492+
topic_timeout_overrides: dict[str, int] | None = None,
493+
) -> dict[str, Any]:
433494
rows: list[dict[str, Any]] = []
434-
topics = _example_topics()
435495

436496
summary = {
437497
"total_topics": len(topics),
@@ -444,10 +504,9 @@ def _help_similarity() -> dict[str, Any]:
444504
}
445505

446506
scores: list[float] = []
447-
topic_timeouts = {
448-
"DecodingExampleWithHist": 240,
449-
"nSTATPaperExamples": 240,
450-
}
507+
topic_timeouts = dict(DEFAULT_TOPIC_TIMEOUT_OVERRIDES)
508+
if topic_timeout_overrides:
509+
topic_timeouts.update(topic_timeout_overrides)
451510
for idx, (title, target) in enumerate(topics, start=1):
452511
stem = Path(target).stem
453512
m_rel = f"helpfiles/{stem}.m"
@@ -464,7 +523,7 @@ def _help_similarity() -> dict[str, Any]:
464523
print(f"[help {idx}/{len(topics)}] {stem}", flush=True)
465524

466525
py = _run_python_topic(stem)
467-
timeout_s = topic_timeouts.get(stem, 120)
526+
timeout_s = topic_timeouts.get(stem, default_timeout_s)
468527
ml = _run_matlab_help_script(script_rel, timeout_s=timeout_s)
469528

470529
if py.get("ok"):
@@ -512,6 +571,7 @@ def _help_similarity() -> dict[str, Any]:
512571
"matlab_script_used": ml.get("script_used", script_rel),
513572
"matlab_fallback_script_used": ml.get("fallback_script_used", ""),
514573
"matlab_runtime_s": ml.get("runtime_s"),
574+
"matlab_timeout_s": timeout_s,
515575
"scalar_overlap": scalar_cmp,
516576
"similarity_score": score,
517577
}
@@ -521,12 +581,20 @@ def _help_similarity() -> dict[str, Any]:
521581
return {"summary": summary, "rows": rows}
522582

523583

524-
def _evaluate_parity_contract(help_rows: list[dict[str, Any]]) -> dict[str, Any]:
584+
def _evaluate_parity_contract(help_rows: list[dict[str, Any]], topics_filter: set[str] | None = None) -> dict[str, Any]:
525585
by_topic = {str(r.get("topic", "")): r for r in help_rows}
526586
rows: list[dict[str, Any]] = []
527587
failures: list[str] = []
588+
if topics_filter is None:
589+
contract_items = list(PARITY_CONTRACT.items())
590+
else:
591+
contract_items = [(topic, required_keys) for topic, required_keys in PARITY_CONTRACT.items() if topic in topics_filter]
592+
missing_contract_entries = sorted(topics_filter - set(PARITY_CONTRACT))
593+
for topic in missing_contract_entries:
594+
failures.append(f"{topic}: missing parity contract entry")
595+
rows.append({"topic": topic, "required_keys": [], "status": "missing_contract"})
528596

529-
for topic, required_keys in PARITY_CONTRACT.items():
597+
for topic, required_keys in contract_items:
530598
row = by_topic.get(topic)
531599
if row is None:
532600
failures.append(f"{topic}: missing topic row")
@@ -586,12 +654,18 @@ def _evaluate_parity_contract(help_rows: list[dict[str, Any]]) -> dict[str, Any]
586654

587655

588656
def _evaluate_regression_gate(report: dict[str, Any]) -> dict[str, Any]:
657+
topic_selection = report.get("topic_selection", {})
589658
class_summary = report.get("class_similarity", {}).get("summary", {})
590659
help_summary = report.get("helpfile_similarity", {}).get("summary", {})
591660
help_rows = report.get("helpfile_similarity", {}).get("rows", [])
592661
parity_contract = report.get("parity_contract", {})
593662

594663
failures: list[str] = []
664+
full_suite = bool(topic_selection.get("full_suite", True))
665+
selected_topics = int(topic_selection.get("total_topics", help_summary.get("total_topics", 0)))
666+
python_required = HELP_PYTHON_REQUIRED_OK if full_suite else selected_topics
667+
matlab_required = HELP_MATLAB_MIN_OK if full_suite else selected_topics
668+
scalar_required = SCALAR_OVERLAP_PASS_MIN_TOPICS if full_suite else selected_topics
595669

596670
class_passed = int(class_summary.get("passed", 0))
597671
class_total = int(class_summary.get("total", 0))
@@ -602,18 +676,34 @@ def _evaluate_regression_gate(report: dict[str, Any]) -> dict[str, Any]:
602676

603677
python_ok = int(help_summary.get("python_ok", 0))
604678
total_topics = int(help_summary.get("total_topics", 0))
605-
if python_ok < HELP_PYTHON_REQUIRED_OK or python_ok != total_topics:
606-
failures.append(f"python help gate failed: expected all topics ok, got {python_ok}/{total_topics}")
679+
if python_ok < python_required or python_ok != total_topics:
680+
if full_suite:
681+
failures.append(f"python help gate failed: expected all topics ok, got {python_ok}/{total_topics}")
682+
else:
683+
failures.append(
684+
f"python help gate failed for selected topics: expected {python_required}/{selected_topics}, "
685+
f"got {python_ok}/{total_topics}"
686+
)
607687

608688
matlab_ok = int(help_summary.get("matlab_ok", 0))
609-
if matlab_ok < HELP_MATLAB_MIN_OK:
610-
failures.append(f"matlab help gate failed: minimum {HELP_MATLAB_MIN_OK}, got {matlab_ok}")
689+
if matlab_ok < matlab_required:
690+
if full_suite:
691+
failures.append(f"matlab help gate failed: minimum {HELP_MATLAB_MIN_OK}, got {matlab_ok}")
692+
else:
693+
failures.append(
694+
f"matlab help gate failed for selected topics: minimum {matlab_required}, got {matlab_ok}"
695+
)
611696

612697
scalar_overlap_pass_topics = int(help_summary.get("scalar_overlap_pass_topics", 0))
613-
if scalar_overlap_pass_topics < SCALAR_OVERLAP_PASS_MIN_TOPICS:
614-
failures.append(
615-
f"scalar overlap gate failed: minimum {SCALAR_OVERLAP_PASS_MIN_TOPICS}, got {scalar_overlap_pass_topics}"
616-
)
698+
if scalar_overlap_pass_topics < scalar_required:
699+
if full_suite:
700+
failures.append(
701+
f"scalar overlap gate failed: minimum {SCALAR_OVERLAP_PASS_MIN_TOPICS}, got {scalar_overlap_pass_topics}"
702+
)
703+
else:
704+
failures.append(
705+
f"scalar overlap gate failed for selected topics: minimum {scalar_required}, got {scalar_overlap_pass_topics}"
706+
)
617707

618708
matlab_failed_topics = sorted([str(r.get("topic", "")) for r in help_rows if not bool(r.get("matlab_ok"))])
619709
unexpected_failures = sorted(set(matlab_failed_topics) - KNOWN_MATLAB_HELP_FAILURES)
@@ -642,12 +732,57 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
642732
action="store_true",
643733
help="Return non-zero exit code if regression gate fails.",
644734
)
735+
parser.add_argument(
736+
"--topics",
737+
nargs="+",
738+
default=None,
739+
help="Optional help-topic stems to run (space/comma separated). Default is all topics.",
740+
)
741+
parser.add_argument(
742+
"--default-topic-timeout",
743+
type=int,
744+
default=DEFAULT_HELP_TOPIC_TIMEOUT_S,
745+
help=f"Default MATLAB timeout per topic in seconds (default: {DEFAULT_HELP_TOPIC_TIMEOUT_S}).",
746+
)
747+
parser.add_argument(
748+
"--topic-timeout",
749+
action="append",
750+
default=[],
751+
help="Override per-topic MATLAB timeout using TOPIC=SECONDS (repeatable).",
752+
)
753+
parser.add_argument(
754+
"--report-path",
755+
default="python/reports/python_vs_matlab_similarity_report.json",
756+
help="Output report path (absolute or repo-relative).",
757+
)
645758
return parser.parse_args(argv)
646759

647760

648761
def main(argv: list[str] | None = None) -> int:
649762
args = _parse_args(argv)
650763
report: dict[str, Any] = {}
764+
if args.default_topic_timeout <= 0:
765+
print("--default-topic-timeout must be positive", file=sys.stderr)
766+
return 2
767+
try:
768+
requested_topics = _parse_topics_arg(args.topics)
769+
topics = _resolve_topics(requested_topics)
770+
topic_timeout_overrides = _parse_topic_timeout_overrides(args.topic_timeout)
771+
except ValueError as exc:
772+
print(str(exc), file=sys.stderr)
773+
return 2
774+
775+
selected_topic_stems = [Path(target).stem for _, target in topics]
776+
full_suite = requested_topics is None
777+
report["topic_selection"] = {
778+
"full_suite": full_suite,
779+
"requested_topics": sorted(requested_topics) if requested_topics else [],
780+
"selected_topics": selected_topic_stems,
781+
"total_topics": len(selected_topic_stems),
782+
"default_timeout_s": args.default_topic_timeout,
783+
"topic_timeout_overrides": topic_timeout_overrides,
784+
"force_m_help_scripts": FORCE_M_HELP_SCRIPTS,
785+
}
651786

652787
print("[class] running Python/MATLAB class checks", flush=True)
653788
py_cls = _python_class_checks()
@@ -667,16 +802,28 @@ def main(argv: list[str] | None = None) -> int:
667802
"comparisons": [],
668803
}
669804

670-
report["helpfile_similarity"] = _help_similarity()
671-
report["parity_contract"] = _evaluate_parity_contract(report["helpfile_similarity"]["rows"])
805+
report["helpfile_similarity"] = _help_similarity(
806+
topics=topics,
807+
default_timeout_s=args.default_topic_timeout,
808+
topic_timeout_overrides=topic_timeout_overrides,
809+
)
810+
contract_topics = None if full_suite else set(selected_topic_stems)
811+
report["parity_contract"] = _evaluate_parity_contract(report["helpfile_similarity"]["rows"], topics_filter=contract_topics)
672812
report["regression_gate"] = _evaluate_regression_gate(report)
673813

674-
REPORT_DIR.mkdir(parents=True, exist_ok=True)
675-
out = REPORT_DIR / "python_vs_matlab_similarity_report.json"
814+
out = Path(args.report_path)
815+
if not out.is_absolute():
816+
out = REPO_ROOT / out
817+
out.parent.mkdir(parents=True, exist_ok=True)
676818
out.write_text(json.dumps(report, indent=2), encoding="utf-8")
819+
try:
820+
out_print = str(out.relative_to(REPO_ROOT))
821+
except ValueError:
822+
out_print = str(out)
677823

678824
printable = {
679-
"report": str(out.relative_to(REPO_ROOT)),
825+
"report": out_print,
826+
"topic_selection": report["topic_selection"],
680827
"class_similarity": report["class_similarity"]["summary"],
681828
"helpfile_similarity": report["helpfile_similarity"]["summary"],
682829
"parity_contract": report["parity_contract"],

0 commit comments

Comments
 (0)