1212
1313template = """
1414pytest.{}:
15- extends: .pytest
15+ extends: {}
1616 variables:
1717 PYTESTFILE: {}
1818 EXAMPLEMODEL: {}
2828
2929# Long-running tests will not be bundled with other tests
3030LONGLIST = {'test_hgq_layers' , 'test_hgq_players' , 'test_qkeras' , 'test_pytorch_api' }
31+ KERAS3_LIST = {'test_keras_v3_api' , 'test_hgq2_mha' , 'test_einsum_dense' , 'test_qeinsum' , 'test_multiout_onnx' }
3132
3233# Test files to split by individual test cases
3334# Value = chunk size per CI job
@@ -71,7 +72,7 @@ def generate_test_yaml(test_root='.'):
7172 test_paths = [
7273 path
7374 for path in test_root .glob ('**/test_*.py' )
74- if path .stem not in (BLACKLIST | LONGLIST | set (SPLIT_BY_TEST_CASE .keys ()))
75+ if path .stem not in (BLACKLIST | LONGLIST | set (SPLIT_BY_TEST_CASE .keys ()) | KERAS3_LIST )
7576 ]
7677 need_example_models = [uses_example_model (path ) for path in test_paths ]
7778
@@ -85,7 +86,7 @@ def generate_test_yaml(test_root='.'):
8586 name = '+' .join (names )
8687 test_files = ' ' .join ([str (path .relative_to (test_root )) for path in batch_paths ])
8788 batch_need_example_model = int (any ([need_example_models [i ] for i in batch_idxs ]))
88- diff_yml = yaml .safe_load (template .format (name , test_files , batch_need_example_model ))
89+ diff_yml = yaml .safe_load (template .format (name , '.pytest' , test_files , batch_need_example_model ))
8990 if yml is None :
9091 yml = diff_yml
9192 else :
@@ -96,7 +97,7 @@ def generate_test_yaml(test_root='.'):
9697 name = path .stem .replace ('test_' , '' )
9798 test_file = str (path .relative_to (test_root ))
9899 needs_examples = uses_example_model (path )
99- diff_yml = yaml .safe_load (template .format (name , test_file , int (needs_examples )))
100+ diff_yml = yaml .safe_load (template .format (name , '.pytest' , test_file , int (needs_examples )))
100101 yml .update (diff_yml )
101102
102103 test_paths = [path for path in test_root .glob ('**/test_*.py' ) if path .stem in SPLIT_BY_TEST_CASE ]
@@ -111,12 +112,27 @@ def generate_test_yaml(test_root='.'):
111112 for i , batch in enumerate (batched (test_ids , chunk_size )):
112113 job_name = f'{ name_base } _part{ i } '
113114 test_file_args = ' ' .join (batch ).strip ().replace ('\n ' , ' ' )
114- diff_yml = yaml .safe_load (template .format (job_name , test_file_args , int (needs_examples )))
115+ diff_yml = yaml .safe_load (template .format (job_name , '.pytest' , test_file_args , int (needs_examples )))
115116 if yml is None :
116117 yml = diff_yml
117118 else :
118119 yml .update (diff_yml )
119120
121+ keras3_paths = [path for path in test_root .glob ('**/test_*.py' ) if path .stem in KERAS3_LIST ]
122+ keras3_need_examples = [uses_example_model (path ) for path in keras3_paths ]
123+
124+ k3_idxs = list (range (len (keras3_need_examples )))
125+ k3_idxs = sorted (k3_idxs , key = lambda i : f'{ keras3_need_examples [i ]} _{ path_to_name (keras3_paths [i ])} ' )
126+
127+ for batch_idxs in batched (k3_idxs , n_test_files_per_yml ):
128+ batch_paths : list [Path ] = [keras3_paths [i ] for i in batch_idxs ]
129+ names = [path_to_name (path ) for path in batch_paths ]
130+ name = 'keras3-' + '+' .join (names )
131+ test_files = ' ' .join ([str (path .relative_to (test_root )) for path in batch_paths ])
132+ batch_need_example_model = int (any ([keras3_need_examples [i ] for i in batch_idxs ]))
133+ diff_yml = yaml .safe_load (template .format (name , '.pytest-keras3-only' , test_files , batch_need_example_model ))
134+ yml .update (diff_yml )
135+
120136 return yml
121137
122138
0 commit comments