45
45
import multiprocessing
46
46
import os
47
47
import pickle
48
+ import platform
48
49
import re
49
50
import shlex
50
51
import signal
70
71
TAGGED_TEST_ROOT = (DIR .parent .parent / 'lib-python' / '3' / 'test' ).resolve ()
71
72
IS_GRAALPY = sys .implementation .name == 'graalpy'
72
73
74
+ PLATFORM_KEYS = {sys .platform , platform .machine (), sys .implementation .name }
75
+ if IS_GRAALPY :
76
+ # noinspection PyUnresolvedReferences
77
+ PLATFORM_KEYS .add ('native_image' if __graalpython__ .is_native else 'jvm' )
78
+
79
+ CURRENT_PLATFORM_KEYS = frozenset ({f'{ sys .platform } -{ platform .machine ()} ' })
80
+
73
81
74
82
class Logger :
75
83
report_incomplete = sys .stdout .isatty ()
@@ -433,16 +441,37 @@ def generate_tags(self, append=False):
433
441
by_file [result .test_id .test_file ].append (result )
434
442
for test_file , results in by_file .items ():
435
443
test_file = configure_test_file (test_file )
436
- tag_file = test_file .get_tag_file ()
437
- if not tag_file :
438
- log (f"WARNNING: no tag directory for test file { test_file } " )
439
- continue
440
- tags = {result .test_id .test_name for result in results if result .status == TestStatus .SUCCESS }
441
- if append :
442
- tags |= {test .test_name for test in read_tags (test_file )}
443
- with open (tag_file , 'w' ) as f :
444
- for test_name in sorted (tags ):
445
- f .write (f'{ test_name } \n ' )
444
+ new_tags = [
445
+ Tag (result .test_id , keys = CURRENT_PLATFORM_KEYS )
446
+ for result in results if result .status == TestStatus .SUCCESS
447
+ ]
448
+ tags = merge_tags (read_tags (test_file ), new_tags , append = append )
449
+ write_tags (test_file , tags )
450
+
451
+
452
+ def merge_tags (original : typing .Iterable ['Tag' ], new : typing .Iterable ['Tag' ], append : bool ) -> set ['Tag' ]:
453
+ original_by_name = {t .test_id .test_name : t .keys for t in original }
454
+ merged = {}
455
+ for tag in new :
456
+ if existing_keys := original_by_name .get (tag .test_id .test_name ):
457
+ merged [tag .test_id .test_name ] = tag .with_merged_keys (existing_keys )
458
+ else :
459
+ merged [tag .test_id .test_name ] = tag
460
+ if append :
461
+ for test_name , tag in original_by_name .items ():
462
+ if test_name not in merged :
463
+ merged [test_name ] = tag
464
+ return set (merged .values ())
465
+
466
+
467
+ def write_tags (test_file : 'TestFile' , tags : typing .Iterable ['Tag' ]):
468
+ tag_file = test_file .get_tag_file ()
469
+ if not tag_file :
470
+ log (f"WARNING: no tag directory for test file { test_file } " )
471
+ return
472
+ with open (tag_file , 'w' ) as f :
473
+ for tag in sorted (tags , key = lambda t : t .test_id .test_name ):
474
+ f .write (f'{ tag } \n ' )
446
475
447
476
448
477
def interrupt_process (process : subprocess .Popen ):
@@ -778,6 +807,10 @@ def run_in_subprocess_and_watch(self):
778
807
raise RuntimeError ("Worker is not making progress" )
779
808
780
809
810
+ def platform_keys_match (items : typing .Iterable [str ]):
811
+ return any (all (key in PLATFORM_KEYS for key in item .split ('-' )) for item in items )
812
+
813
+
781
814
@dataclass
782
815
class TestFileConfig :
783
816
serial : bool = False
@@ -787,15 +820,11 @@ class TestFileConfig:
787
820
788
821
@classmethod
789
822
def from_dict (cls , config : dict ):
790
- exclude_keys = {sys .platform }
791
- if IS_GRAALPY :
792
- # noinspection PyUnresolvedReferences
793
- exclude_keys .add ('native_image' if __graalpython__ .is_native else 'jvm' )
794
823
return cls (
795
824
serial = config .get ('serial' , cls .serial ),
796
825
partial_splits = config .get ('partial_splits_individual_tests' , cls .partial_splits ),
797
826
per_test_timeout = config .get ('per_test_timeout' , cls .per_test_timeout ),
798
- exclude = bool ( set ( config .get ('exclude_on' , set ())) & exclude_keys ),
827
+ exclude = platform_keys_match ( config .get ('exclude_on' , ())),
799
828
)
800
829
801
830
def combine (self , other : 'TestFileConfig' ):
@@ -925,22 +954,22 @@ def __str__(self):
925
954
926
955
927
956
def filter_tree (test_file : TestFile , test_suite : unittest .TestSuite , specifiers : list [TestSpecifier ],
928
- tags : list [TestId ] | None ):
957
+ tagged_ids : list [TestId ] | None ):
929
958
keep_tests = []
930
959
collected_tests = []
931
960
for test in test_suite :
932
961
# When test loading fails, unittest just creates an instance of _FailedTest
933
962
if exception := getattr (test , '_exception' , None ):
934
963
raise exception
935
964
if hasattr (test , '__iter__' ):
936
- sub_collected = filter_tree (test_file , test , specifiers , tags )
965
+ sub_collected = filter_tree (test_file , test , specifiers , tagged_ids )
937
966
if sub_collected :
938
967
keep_tests .append (test )
939
968
collected_tests += sub_collected
940
969
else :
941
970
test_id = TestId .from_test_case (test_file .path , test )
942
971
if any (s .match (test_id ) for s in specifiers ):
943
- if tags is None or test_id in tags :
972
+ if tagged_ids is None or test_id in tagged_ids :
944
973
keep_tests .append (test )
945
974
collected_tests .append (Test (test_id , test_file ))
946
975
test_suite ._tests = keep_tests
@@ -991,18 +1020,18 @@ def collect_module(test_file: TestFile, specifiers: list[TestSpecifier], use_tag
991
1020
sys .path .insert (0 , str (config .rootdir ))
992
1021
try :
993
1022
loader = TopLevelFunctionLoader () if config .run_top_level_functions else unittest .TestLoader ()
994
- tags = None
1023
+ tagged_ids = None
995
1024
if use_tags and config .tags_dir :
996
- tags = read_tags (test_file )
997
- if not tags :
1025
+ tagged_ids = [ tag . test_id for tag in read_tags (test_file ) if platform_keys_match ( tag . keys )]
1026
+ if not tagged_ids :
998
1027
return None
999
1028
test_module = test_path_to_module (test_file )
1000
1029
try :
1001
1030
test_suite = loader .loadTestsFromName (test_module )
1002
1031
except unittest .SkipTest as e :
1003
1032
log (f"Test file { test_file } skipped: { e } " )
1004
1033
return
1005
- collected_tests = filter_tree (test_file , test_suite , specifiers , tags )
1034
+ collected_tests = filter_tree (test_file , test_suite , specifiers , tagged_ids )
1006
1035
if partial and test_file .test_config .partial_splits :
1007
1036
selected , total = partial
1008
1037
collected_tests = collected_tests [selected ::total ]
@@ -1066,15 +1095,33 @@ def collect(all_specifiers: list[TestSpecifier], *, use_tags=False, ignore=None,
1066
1095
return to_run
1067
1096
1068
1097
1069
- def read_tags (test_file : TestFile ) -> list [TestId ]:
1098
+ @dataclass (frozen = True )
1099
+ class Tag :
1100
+ test_id : TestId
1101
+ keys : frozenset [str ]
1102
+
1103
+ def with_merged_keys (self , keys : typing .AbstractSet [str ]) -> 'Tag' :
1104
+ return Tag (self .test_id , self .keys | keys )
1105
+
1106
+ def __str__ (self ):
1107
+ return f'{ self .test_id .test_name } @ { "," .join (sorted (self .keys ))} '
1108
+
1109
+
1110
+ def read_tags (test_file : TestFile ) -> list [Tag ]:
1070
1111
tag_file = test_file .get_tag_file ()
1071
1112
tags = []
1072
1113
if tag_file .exists ():
1073
1114
with open (tag_file ) as f :
1074
1115
for line in f :
1075
- test = line .strip ()
1076
- tags .append (TestId (test_file .path , test ))
1077
- return tags
1116
+ test , _ , keys = line .partition ('@' )
1117
+ test = test .strip ()
1118
+ keys = keys .strip ()
1119
+ if not keys :
1120
+ log (f'WARNING: invalid tag { test } : missing platform keys' )
1121
+ tags .append (Tag (
1122
+ TestId (test_file .path , test ),
1123
+ frozenset (keys .split (',' )),
1124
+ ))
1078
1125
return tags
1079
1126
1080
1127
0 commit comments