Skip to content

Commit 1d3d418

Browse files
authored
Merge pull request #3255 from Flamefire/20240313110424_new_pr_pytorch
Explicitly mention that the PyTorch easyblock needs updating when failing for this reason
2 parents 8b51ccf + ad19427 commit 1d3d418

File tree

1 file changed

+132
-63
lines changed

1 file changed

+132
-63
lines changed

easybuild/easyblocks/p/pytorch.py

Lines changed: 132 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,41 @@
4444

4545

4646
if sys.version_info >= (3, 9):
47-
from typing import NamedTuple
48-
FailedTestNames = NamedTuple('FailedTestNames', [('error', list[str]), ('fail', list[str])])
49-
TestSuiteResult = NamedTuple('TestSuiteResult', [('name', str), ('summary', str)])
50-
TestResult = NamedTuple('TestResult', [('test_cnt', int),
51-
('error_cnt', int),
52-
('failure_cnt', int),
53-
('failed_suites', list[TestSuiteResult])])
47+
from dataclasses import dataclass
48+
49+
@dataclass
50+
class FailedTestNames:
51+
"""Hold list of tests names that failed with error or failure"""
52+
error: list[str]
53+
fail: list[str]
54+
55+
@dataclass
56+
class TestSuiteResult:
57+
"""Hold the name of a test suite and a summary of the failures"""
58+
name: str
59+
summary: str
60+
61+
@dataclass
62+
class TestResult:
63+
"""Status report and results of a test run"""
64+
test_cnt: int
65+
error_cnt: int
66+
failure_cnt: int
67+
failed_suites: list[TestSuiteResult]
68+
terminated_suites: dict[str, str] # Name and signal of terminated suites
69+
all_failed_suites: set[str] # Names of all failed suites
5470
else:
5571
from collections import namedtuple
5672
FailedTestNames = namedtuple('FailedTestNames', ('error', 'fail'))
5773
TestSuiteResult = namedtuple('TestSuiteResult', ('name', 'summary'))
58-
TestResult = namedtuple('TestResult', ('test_cnt', 'error_cnt', 'failure_cnt', 'failed_suites'))
74+
TerminatedTestSuite = namedtuple('TerminatedTestSuite', ('name', 'signal'))
75+
TestResult = namedtuple('TestResult', ('test_cnt',
76+
'error_cnt',
77+
'failure_cnt',
78+
'failed_suites',
79+
'terminated_suites',
80+
'all_failed_suites'
81+
))
5982

6083

6184
def find_failed_test_names(tests_out):
@@ -198,7 +221,17 @@ def get_count_for_pattern(regex, text):
198221
for m in re.finditer(regex, tests_out, re.M):
199222
test_cnt += sum(get_count_for_pattern(p, m.group("summary")) for p in count_patterns)
200223

201-
return TestResult(test_cnt=test_cnt, error_cnt=error_cnt, failure_cnt=failure_cnt, failed_suites=failed_suites)
224+
# Gather all failed tests suites in case we missed any,
225+
# e.g. when it exited due to syntax errors or with a signal such as SIGSEGV
226+
failed_suites_and_signal = set(
227+
re.findall(r"^(?P<test_name>.*) failed!(?: Received signal: (\w+))?\s*$", tests_out, re.M)
228+
)
229+
230+
return TestResult(test_cnt=test_cnt, error_cnt=error_cnt, failure_cnt=failure_cnt,
231+
failed_suites=failed_suites,
232+
# Assumes that the suite name is unique
233+
terminated_suites={name: signal for name, signal in failed_suites_and_signal if signal},
234+
all_failed_suites={i[0] for i in failed_suites_and_signal})
202235

203236

204237
class EB_PyTorch(PythonPackage):
@@ -462,17 +495,17 @@ def test_step(self):
462495
'excluded_tests': ' '.join(excluded_tests)
463496
})
464497

465-
test_result = super(EB_PyTorch, self).test_step(return_output_ec=True)
466-
if test_result is None:
498+
parsed_test_result = super(EB_PyTorch, self).test_step(return_output_ec=True)
499+
if parsed_test_result is None:
467500
if self.cfg['runtest'] is False:
468501
msg = "Do not set 'runtest' to False, use --skip-test-step instead."
469502
else:
470503
msg = "Tests did not run. Make sure 'runtest' is set to a command."
471504
raise EasyBuildError(msg)
472505

473-
tests_out, tests_ec = test_result
506+
tests_out, tests_ec = parsed_test_result
474507

475-
# Show failed subtests to aid in debugging failures
508+
# Show failed subtests, if any, to aid in debugging failures
476509
failed_test_names = find_failed_test_names(tests_out)
477510
if failed_test_names.error or failed_test_names.fail:
478511
msg = []
@@ -485,60 +518,92 @@ def test_step(self):
485518
self.log.warning("\n".join(msg))
486519

487520
# Create clear summary report
488-
test_result = parse_test_log(tests_out)
489-
failure_report = ['%s (%s)' % (suite.name, suite.summary) for suite in test_result.failed_suites]
490-
failed_test_suites = set(suite.name for suite in test_result.failed_suites)
491-
# Gather all failed tests suites in case we missed any (e.g. when it exited due to syntax errors)
492-
# Also unique to be able to compare the lists below
493-
all_failed_test_suites = set(
494-
re.findall(r"^(?P<test_name>.*) failed!(?: Received signal: \w+)?\s*$", tests_out, re.M)
495-
)
496-
# If we missed any test suites prepend a list of all failed test suites
521+
parsed_test_result = parse_test_log(tests_out)
522+
# Use a list of messages we can later join together
523+
failure_msgs = ['\t%s (%s)' % (suite.name, suite.summary) for suite in parsed_test_result.failed_suites]
524+
# These were accounted for
525+
failed_test_suites = set(suite.name for suite in parsed_test_result.failed_suites)
526+
# Those are all that failed according to the summary output
527+
all_failed_test_suites = parsed_test_result.all_failed_suites
528+
# We should have determined all failed test suites and only those.
529+
# Otherwise show the mismatch and terminate later
497530
if failed_test_suites != all_failed_test_suites:
498-
failure_report = ['Failed tests (suites/files):'] + failure_report
531+
failure_msgs.insert(0, 'Failed tests (suites/files):')
499532
# Test suites where we didn't match a specific regexp and hence likely didn't count the failures
500-
failure_report.extend('+ %s' % t for t in sorted(all_failed_test_suites - failed_test_suites))
533+
uncounted_test_suites = all_failed_test_suites - failed_test_suites
534+
if uncounted_test_suites:
535+
failure_msgs.append('Could not count failed tests for the following test suites/files:')
536+
for suite_name in sorted(uncounted_test_suites):
537+
try:
538+
signal = parsed_test_result.terminated_suites[suite_name]
539+
reason = f'Terminated with {signal}'
540+
except KeyError:
541+
# Not ended with signal, might have failed due to e.g. syntax errors
542+
reason = 'Undetected or did not run properly'
543+
failure_msgs.append(f'\t{suite_name} ({reason})')
501544
# Test suites not included in the catch-all regexp but counted. Should be empty.
502-
failure_report.extend('? %s' % t for t in sorted(failed_test_suites - all_failed_test_suites))
503-
504-
failure_report = '\n'.join(failure_report)
545+
unexpected_test_suites = failed_test_suites - all_failed_test_suites
546+
if unexpected_test_suites:
547+
failure_msgs.append('Counted failures of tests from the following test suites/files that are not '
548+
'contained in the summary output of PyTorch:')
549+
failure_msgs.extend(sorted(unexpected_test_suites))
505550

506551
# Calculate total number of unsuccesful and total tests
507-
failed_test_cnt = test_result.failure_cnt + test_result.error_cnt
552+
failed_test_cnt = parsed_test_result.failure_cnt + parsed_test_result.error_cnt
553+
# Only add count message if we detected any failed tests
554+
if failed_test_cnt > 0:
555+
failure_or_failures = 'failure' if parsed_test_result.failure_cnt == 1 else 'failures'
556+
error_or_errors = 'error' if parsed_test_result.error_cnt == 1 else 'errors'
557+
failure_msgs.insert(0, "%d test %s, %d test %s (out of %d):" % (
558+
parsed_test_result.failure_cnt, failure_or_failures,
559+
parsed_test_result.error_cnt, error_or_errors,
560+
parsed_test_result.test_cnt
561+
))
562+
563+
# Assemble final report
564+
failure_report = '\n'.join(failure_msgs)
565+
566+
if failed_test_suites != all_failed_test_suites:
567+
# Fail because we can't be sure how many tests failed
568+
# so comparing to max_failed_tests cannot reasonably be done
569+
if failed_test_suites | set(parsed_test_result.terminated_suites) == all_failed_test_suites:
570+
# All failed test suites are either counted or terminated with a signal
571+
msg = ('Failing because these test suites were terminated which makes it impossible'
572+
'to accurately count the failed tests: ')
573+
msg += ", ".join("%s(%s)" % name_signal
574+
for name_signal in sorted(parsed_test_result.terminated_suites.items()))
575+
elif len(failed_test_suites) < len(all_failed_test_suites):
576+
msg = ('Failing because not all failed tests could be determined. '
577+
'Tests failed to start or the test accounting in the PyTorch EasyBlock needs updating!\n'
578+
'Missing: ' + ', '.join(sorted(all_failed_test_suites - failed_test_suites)))
579+
else:
580+
msg = ('Failing because there were unexpected failures detected: ' +
581+
', '.join(sorted(failed_test_suites - all_failed_test_suites)))
582+
raise EasyBuildError(msg + '\n' +
583+
'You can check the test failures (in the log) manually and if they are harmless, '
584+
'use --ignore-test-failures to make the test step pass.\n' + failure_report)
508585

509586
if failed_test_cnt > 0:
510587
max_failed_tests = self.cfg['max_failed_tests']
511588

512-
failure_or_failures = 'failure' if test_result.failure_cnt == 1 else 'failures'
513-
error_or_errors = 'error' if test_result.error_cnt == 1 else 'errors'
514-
msg = "%d test %s, %d test %s (out of %d):\n" % (
515-
test_result.failure_cnt, failure_or_failures,
516-
test_result.error_cnt, error_or_errors,
517-
test_result.test_cnt
518-
)
519-
msg += failure_report
520-
521-
# If no tests are supposed to fail or some failed for which we were not able to count errors fail now
522-
if max_failed_tests == 0 or failed_test_suites != all_failed_test_suites:
523-
raise EasyBuildError(msg)
524-
else:
525-
msg += '\n\n' + ' '.join([
526-
"The PyTorch test suite is known to include some flaky tests,",
527-
"which may fail depending on the specifics of the system or the context in which they are run.",
528-
"For this PyTorch installation, EasyBuild allows up to %d tests to fail." % max_failed_tests,
529-
"We recommend to double check that the failing tests listed above ",
530-
"are known to be flaky, or do not affect your intended usage of PyTorch.",
531-
"In case of doubt, reach out to the EasyBuild community (via GitHub, Slack, or mailing list).",
532-
])
533-
# Print to console, the user should really be aware that we are accepting failing tests here...
534-
print_warning(msg)
535-
536-
# Also log this warning in the file log
537-
self.log.warning(msg)
538-
539-
if failed_test_cnt > max_failed_tests:
540-
raise EasyBuildError("Too many failed tests (%d), maximum allowed is %d",
541-
failed_test_cnt, max_failed_tests)
589+
# If no tests are supposed to fail don't print the explanation, just fail
590+
if max_failed_tests == 0:
591+
raise EasyBuildError(failure_report)
592+
msg = failure_report + '\n\n' + ' '.join([
593+
"The PyTorch test suite is known to include some flaky tests,",
594+
"which may fail depending on the specifics of the system or the context in which they are run.",
595+
"For this PyTorch installation, EasyBuild allows up to %d tests to fail." % max_failed_tests,
596+
"We recommend to double check that the failing tests listed above ",
597+
"are known to be flaky, or do not affect your intended usage of PyTorch.",
598+
"In case of doubt, reach out to the EasyBuild community (via GitHub, Slack, or mailing list).",
599+
])
600+
# Print to console in addition to file,
601+
# the user should really be aware that we are accepting failing tests here...
602+
print_warning(msg, log=self.log)
603+
604+
if failed_test_cnt > max_failed_tests:
605+
raise EasyBuildError("Too many failed tests (%d), maximum allowed is %d",
606+
failed_test_cnt, max_failed_tests)
542607
elif failure_report:
543608
raise EasyBuildError("Test ended with failures! Exit code: %s\n%s", tests_ec, failure_report)
544609
elif tests_ec:
@@ -576,11 +641,11 @@ def make_module_req_guess(self):
576641
return guesses
577642

578643

579-
if __name__ == '__main__':
580-
arg = sys.argv[1]
581-
if not os.path.isfile(arg):
582-
raise RuntimeError('Expected a test result file to parse, got: ' + arg)
583-
with open(arg, 'r') as f:
644+
def parse_logfile(file):
645+
"""Parse the EB logfile and print the failed tests"""
646+
if not os.path.isfile(file):
647+
raise RuntimeError('Expected a test result file to parse, got: ' + file)
648+
with open(file, 'r') as f:
584649
content = f.read()
585650
m = re.search(r'cmd .*python[^ ]* run_test\.py .* exited with exit code.*output', content)
586651
if m:
@@ -592,3 +657,7 @@ def make_module_req_guess(self):
592657

593658
print("Failed test names: ", find_failed_test_names(content))
594659
print("Test result: ", parse_test_log(content))
660+
661+
662+
if __name__ == '__main__':
663+
parse_logfile(sys.argv[1])

0 commit comments

Comments
 (0)