Skip to content

Commit 59ff611

Browse files
committed
Make tags per platform
1 parent 09f44d6 commit 59ff611

File tree

1 file changed

+73
-26
lines changed
  • graalpython/com.oracle.graal.python.test/src

1 file changed

+73
-26
lines changed

graalpython/com.oracle.graal.python.test/src/runner.py

Lines changed: 73 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import multiprocessing
4646
import os
4747
import pickle
48+
import platform
4849
import re
4950
import shlex
5051
import signal
@@ -70,6 +71,13 @@
7071
TAGGED_TEST_ROOT = (DIR.parent.parent / 'lib-python' / '3' / 'test').resolve()
7172
IS_GRAALPY = sys.implementation.name == 'graalpy'
7273

74+
PLATFORM_KEYS = {sys.platform, platform.machine(), sys.implementation.name}
75+
if IS_GRAALPY:
76+
# noinspection PyUnresolvedReferences
77+
PLATFORM_KEYS.add('native_image' if __graalpython__.is_native else 'jvm')
78+
79+
CURRENT_PLATFORM_KEYS = frozenset({f'{sys.platform}-{platform.machine()}'})
80+
7381

7482
class Logger:
7583
report_incomplete = sys.stdout.isatty()
@@ -433,16 +441,37 @@ def generate_tags(self, append=False):
433441
by_file[result.test_id.test_file].append(result)
434442
for test_file, results in by_file.items():
435443
test_file = configure_test_file(test_file)
436-
tag_file = test_file.get_tag_file()
437-
if not tag_file:
438-
log(f"WARNNING: no tag directory for test file {test_file}")
439-
continue
440-
tags = {result.test_id.test_name for result in results if result.status == TestStatus.SUCCESS}
441-
if append:
442-
tags |= {test.test_name for test in read_tags(test_file)}
443-
with open(tag_file, 'w') as f:
444-
for test_name in sorted(tags):
445-
f.write(f'{test_name}\n')
444+
new_tags = [
445+
Tag(result.test_id, keys=CURRENT_PLATFORM_KEYS)
446+
for result in results if result.status == TestStatus.SUCCESS
447+
]
448+
tags = merge_tags(read_tags(test_file), new_tags, append=append)
449+
write_tags(test_file, tags)
450+
451+
452+
def merge_tags(original: typing.Iterable['Tag'], new: typing.Iterable['Tag'], append: bool) -> set['Tag']:
453+
original_by_name = {t.test_id.test_name: t.keys for t in original}
454+
merged = {}
455+
for tag in new:
456+
if existing_keys := original_by_name.get(tag.test_id.test_name):
457+
merged[tag.test_id.test_name] = tag.with_merged_keys(existing_keys)
458+
else:
459+
merged[tag.test_id.test_name] = tag
460+
if append:
461+
for test_name, tag in original_by_name.items():
462+
if test_name not in merged:
463+
merged[test_name] = tag
464+
return set(merged.values())
465+
466+
467+
def write_tags(test_file: 'TestFile', tags: typing.Iterable['Tag']):
468+
tag_file = test_file.get_tag_file()
469+
if not tag_file:
470+
log(f"WARNING: no tag directory for test file {test_file}")
471+
return
472+
with open(tag_file, 'w') as f:
473+
for tag in sorted(tags, key=lambda t: t.test_id.test_name):
474+
f.write(f'{tag}\n')
446475

447476

448477
def interrupt_process(process: subprocess.Popen):
@@ -778,6 +807,10 @@ def run_in_subprocess_and_watch(self):
778807
raise RuntimeError("Worker is not making progress")
779808

780809

810+
def platform_keys_match(items: typing.Iterable[str]):
811+
return any(all(key in PLATFORM_KEYS for key in item.split('-')) for item in items)
812+
813+
781814
@dataclass
782815
class TestFileConfig:
783816
serial: bool = False
@@ -787,15 +820,11 @@ class TestFileConfig:
787820

788821
@classmethod
789822
def from_dict(cls, config: dict):
790-
exclude_keys = {sys.platform}
791-
if IS_GRAALPY:
792-
# noinspection PyUnresolvedReferences
793-
exclude_keys.add('native_image' if __graalpython__.is_native else 'jvm')
794823
return cls(
795824
serial=config.get('serial', cls.serial),
796825
partial_splits=config.get('partial_splits_individual_tests', cls.partial_splits),
797826
per_test_timeout=config.get('per_test_timeout', cls.per_test_timeout),
798-
exclude=bool(set(config.get('exclude_on', set())) & exclude_keys),
827+
exclude=platform_keys_match(config.get('exclude_on', ())),
799828
)
800829

801830
def combine(self, other: 'TestFileConfig'):
@@ -925,22 +954,22 @@ def __str__(self):
925954

926955

927956
def filter_tree(test_file: TestFile, test_suite: unittest.TestSuite, specifiers: list[TestSpecifier],
928-
tags: list[TestId] | None):
957+
tagged_ids: list[TestId] | None):
929958
keep_tests = []
930959
collected_tests = []
931960
for test in test_suite:
932961
# When test loading fails, unittest just creates an instance of _FailedTest
933962
if exception := getattr(test, '_exception', None):
934963
raise exception
935964
if hasattr(test, '__iter__'):
936-
sub_collected = filter_tree(test_file, test, specifiers, tags)
965+
sub_collected = filter_tree(test_file, test, specifiers, tagged_ids)
937966
if sub_collected:
938967
keep_tests.append(test)
939968
collected_tests += sub_collected
940969
else:
941970
test_id = TestId.from_test_case(test_file.path, test)
942971
if any(s.match(test_id) for s in specifiers):
943-
if tags is None or test_id in tags:
972+
if tagged_ids is None or test_id in tagged_ids:
944973
keep_tests.append(test)
945974
collected_tests.append(Test(test_id, test_file))
946975
test_suite._tests = keep_tests
@@ -991,18 +1020,18 @@ def collect_module(test_file: TestFile, specifiers: list[TestSpecifier], use_tag
9911020
sys.path.insert(0, str(config.rootdir))
9921021
try:
9931022
loader = TopLevelFunctionLoader() if config.run_top_level_functions else unittest.TestLoader()
994-
tags = None
1023+
tagged_ids = None
9951024
if use_tags and config.tags_dir:
996-
tags = read_tags(test_file)
997-
if not tags:
1025+
tagged_ids = [tag.test_id for tag in read_tags(test_file) if platform_keys_match(tag.keys)]
1026+
if not tagged_ids:
9981027
return None
9991028
test_module = test_path_to_module(test_file)
10001029
try:
10011030
test_suite = loader.loadTestsFromName(test_module)
10021031
except unittest.SkipTest as e:
10031032
log(f"Test file {test_file} skipped: {e}")
10041033
return
1005-
collected_tests = filter_tree(test_file, test_suite, specifiers, tags)
1034+
collected_tests = filter_tree(test_file, test_suite, specifiers, tagged_ids)
10061035
if partial and test_file.test_config.partial_splits:
10071036
selected, total = partial
10081037
collected_tests = collected_tests[selected::total]
@@ -1066,15 +1095,33 @@ def collect(all_specifiers: list[TestSpecifier], *, use_tags=False, ignore=None,
10661095
return to_run
10671096

10681097

1069-
def read_tags(test_file: TestFile) -> list[TestId]:
1098+
@dataclass(frozen=True)
1099+
class Tag:
1100+
test_id: TestId
1101+
keys: frozenset[str]
1102+
1103+
def with_merged_keys(self, keys: typing.AbstractSet[str]) -> 'Tag':
1104+
return Tag(self.test_id, self.keys | keys)
1105+
1106+
def __str__(self):
1107+
return f'{self.test_id.test_name} @ {",".join(sorted(self.keys))}'
1108+
1109+
1110+
def read_tags(test_file: TestFile) -> list[Tag]:
10701111
tag_file = test_file.get_tag_file()
10711112
tags = []
10721113
if tag_file.exists():
10731114
with open(tag_file) as f:
10741115
for line in f:
1075-
test = line.strip()
1076-
tags.append(TestId(test_file.path, test))
1077-
return tags
1116+
test, _, keys = line.partition('@')
1117+
test = test.strip()
1118+
keys = keys.strip()
1119+
if not keys:
1120+
log(f'WARNING: invalid tag {test}: missing platform keys')
1121+
tags.append(Tag(
1122+
TestId(test_file.path, test),
1123+
frozenset(keys.split(',')),
1124+
))
10781125
return tags
10791126

10801127

0 commit comments

Comments
 (0)