@@ -1,7 +1,11 @@
+import logging
 import os
+import json
+from pathlib import Path

+from collections import defaultdict
 import pytest
-from typing import List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 from _pytest.nodes import Item
 import requests
 from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph, get_default_graph
@@ -18,6 +22,8 @@
     wait_for_writes_to_sync,
 )

+logger = logging.getLogger(__name__)
+
 # Disable telemetry
 os.environ["DATAHUB_TELEMETRY_ENABLED"] = "false"
 # Suppress logging manager to prevent I/O errors during pytest teardown
@@ -149,64 +155,121 @@ def bin_pack_tasks(tasks, n_buckets):

     return buckets

-def get_batch_start_end(num_tests: int) -> Tuple[int, int]:
-    batch_count = env_vars.get_batch_count()

-    batch_number = env_vars.get_batch_number()
+def load_pytest_test_weights() -> Dict[str, float]:
+    """
+    Load pytest test weights from JSON file.

-    if batch_count == 0 or batch_count > num_tests:
-        raise ValueError(
-            f"Invalid batch count {batch_count}: must be >0 and <= {num_tests} (num_tests)"
-        )
-    if batch_number >= batch_count:
-        raise ValueError(
-            f"Invalid batch number: {batch_number}, must be less than {batch_count} (zer0 based index)"
-        )
+    Returns:
+        Dictionary mapping test IDs (classname::test_name) to durations in seconds.
+        Returns empty dict if weights file doesn't exist.
+    """
+    weights_file = Path(__file__).parent / "pytest_test_weights.json"
+
+    if not weights_file.exists():
+        return {}
+
+    try:
+        with open(weights_file) as f:
+            weights_data = json.load(f)
+
+        # Convert to dict: {"test_e2e::test_gms_get_dataset": 262.807, ...}
+        return {
+            item["testId"]: float(item["duration"][:-1])  # Strip 's' suffix
+            for item in weights_data
+        }
+    except Exception as e:
+        logger.warning(f"Warning: Failed to load pytest test weights: {e}")
+        return {}
+
+
+def aggregate_module_weights(items: List[Item], test_weights: Dict[str, float]) -> List[Tuple[str, List[Item], float]]:
+    """
+    Group test items by module and aggregate their weights.
+
+    Args:
+        items: List of pytest test items
+        test_weights: Dictionary mapping test IDs to durations
+
+    Returns:
+        List of (module_path, items_in_module, total_weight) tuples
+    """
+
+    # Group items by module (file path)
+    modules: Dict[str, List[Item]] = defaultdict(list)
+    for item in items:
+        # Get the module path from the item's fspath
+        module_path = str(item.fspath)
+        modules[module_path].append(item)

-    batch_size = round(num_tests / batch_count)
+    # Calculate total weight for each module
+    module_data = []
+    for module_path, module_items in modules.items():
+        total_weight = 0.0
+        for item in module_items:
+            # Build test ID from nodeid
+            # nodeid format: "tests/database/test_database.py::test_method"
+            # weights format: "tests.database.test_database::test_method"
+            nodeid = item.nodeid

-    batch_start = batch_size * batch_number
-    batch_end = batch_start + batch_size
-    # We must have exactly as many batches as specified by BATCH_COUNT.
-    if (
-        batch_number == batch_count - 1  # this is the last batch
-    ):  # If ths is last batch put any remaining tests in the last batch.
-        batch_end = num_tests
+            # Convert path separators to dots and remove .py extension
+            # tests/database/test_database.py::test_method -> tests.database.test_database::test_method
+            test_id = nodeid.replace("/", ".").replace(".py::", "::")

-    if batch_count > 0:
-        print(f"Running tests for batch {batch_number} of {batch_count}")
+            weight = test_weights.get(test_id, 1.0)  # Default to 1.0 if not found
+            total_weight += weight
+
+        module_data.append((module_path, module_items, total_weight))
+
+    return module_data

-    return batch_start, batch_end

 def pytest_collection_modifyitems(
     session: pytest.Session, config: pytest.Config, items: List[Item]
 ) -> None:
     if env_vars.get_test_strategy() == "cypress":
         return  # We launch cypress via pytests, but needs a different batching mechanism at cypress level.

-    # If BATCH_COUNT and BATCH_ENV vars are set, splits the pytests to batches and runs filters only the BATCH_NUMBER
-    # batch for execution. Enables multiple parallel launches. Current implementation assumes all test are of equal
-    # weight for batching. TODO. A weighted batching method can help make batches more equal sized by cost.
-    # this effectively is a no-op if BATCH_COUNT=1
-    start_index, end_index = get_batch_start_end(num_tests=len(items))
+    # Get batch configuration
+    batch_count_env = env_vars.get_batch_count()
+    batch_count = int(batch_count_env)
+    batch_number_env = env_vars.get_batch_number()
+    batch_number = int(batch_number_env)

-    # Sort tests but preserve dependency order for library_examples tests
-    # Library example tests should maintain their manifest order to respect dependencies
-    library_example_tests = []
-    other_tests = []
+    if batch_count <= 1:
+        # No batching needed
+        return

-    for item in items:
-        if "test_library_examples" in item.nodeid:
-            library_example_tests.append(item)
-        else:
-            other_tests.append(item)
+    # Load test weights
+    test_weights = load_pytest_test_weights()
+
+    # Group items by module and aggregate weights
+    module_data = aggregate_module_weights(items, test_weights)
+
+    # Sort modules by path for stability
+    module_data.sort(key=lambda x: x[0])
+
+    # Create weighted tuples for bin-packing: (module_path, weight)
+    # We'll also keep track of the items for each module
+    module_map = {module_path: module_items for module_path, module_items, _ in module_data}
+    weighted_modules = [(module_path, total_weight) for module_path, _, total_weight in module_data]
+
+    logger.info(f"Batching {len(items)} tests from {len(weighted_modules)} modules across {batch_count} batches")
+
+    # Apply bin-packing to modules
+    module_batches = bin_pack_tasks(weighted_modules, batch_count)
+
+    # Get the modules for this batch
+    selected_modules = module_batches[batch_number]
+
+    # Flatten back to individual test items
+    # Tests within each module maintain their original collection order
+    selected_items = []
+    for module_path in selected_modules:
+        selected_items.extend(module_map[module_path])

-    # Sort non-library tests alphabetically for stability
-    other_tests.sort(key=lambda x: x.nodeid)
+    logger.info(f"Batch {batch_number}: Running {len(selected_items)} tests from {len(selected_modules)} modules")

-    # Combine: library tests first (in original order), then other tests (sorted)
-    items[:] = library_example_tests + other_tests
+    # Replace items with the filtered list
+    items[:] = selected_items

-    # replace items with the filtered list
-    print(f"Running tests for batch {start_index}-{end_index}")
-    items[:] = items[start_index:end_index]
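
Note on the weights data: the diff does not include pytest_test_weights.json itself, but the parsing code (item["testId"], item["duration"][:-1]) and the ID conversion in aggregate_module_weights imply a list of records keyed by dotted-module test IDs, with durations stored as seconds-with-"s"-suffix strings. A minimal sketch of that assumed shape and how the loader parses it (field values here are illustrative only):

# Assumed shape of pytest_test_weights.json, inferred from the parsing code above;
# the real file is generated elsewhere and may carry extra fields.
sample_weights = [
    {"testId": "test_e2e::test_gms_get_dataset", "duration": "262.807s"},
    {"testId": "tests.database.test_database::test_method", "duration": "31.5s"},
]

# Same parsing as load_pytest_test_weights(): strip the trailing "s", keep seconds as floats.
weights = {item["testId"]: float(item["duration"][:-1]) for item in sample_weights}
print(weights)  # {'test_e2e::test_gms_get_dataset': 262.807, 'tests.database.test_database::test_method': 31.5}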
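
bin_pack_tasks is defined above this hunk and is not shown in the diff; the new code hands it a list of (module_path, weight) tuples plus a bucket count and then iterates each returned bucket as plain module paths. A greedy longest-processing-time sketch consistent with that usage (not necessarily the actual implementation) would be:

from typing import List, Tuple


def bin_pack_tasks(tasks: List[Tuple[str, float]], n_buckets: int) -> List[List[str]]:
    # Illustrative sketch only; the real implementation lives earlier in conftest.py.
    buckets: List[List[str]] = [[] for _ in range(n_buckets)]
    bucket_weights = [0.0] * n_buckets
    # Heaviest modules first, each dropped into the currently lightest bucket.
    for task_id, weight in sorted(tasks, key=lambda t: t[1], reverse=True):
        lightest = bucket_weights.index(min(bucket_weights))
        buckets[lightest].append(task_id)
        bucket_weights[lightest] += weight
    return buckets

Placing heavy modules first and always filling the lightest bucket keeps per-batch wall-clock time roughly balanced, which is what weighting modules by historical duration (rather than splitting by test count, as the removed get_batch_start_end did) is meant to achieve.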