Skip to content

Commit 0baabeb

Browse files
authored
fix(ci): reworked smoke test batching (#15039)
1 parent 5c2ee31 commit 0baabeb

File tree

4 files changed

+1607
-193
lines changed

4 files changed

+1607
-193
lines changed

.github/workflows/docker-unified.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,8 +376,8 @@ jobs:
376376
echo "cypress_batch_count=5" >> "$GITHUB_OUTPUT"
377377
echo "python_batch_count=3" >> "$GITHUB_OUTPUT"
378378
else
379-
echo "cypress_batch_count=11" >> "$GITHUB_OUTPUT"
380-
echo "python_batch_count=6" >> "$GITHUB_OUTPUT"
379+
echo "cypress_batch_count=8" >> "$GITHUB_OUTPUT"
380+
echo "python_batch_count=7" >> "$GITHUB_OUTPUT"
381381
fi
382382
383383
- id: set-matrix
@@ -422,6 +422,7 @@ jobs:
422422
MIXPANEL_PROJECT_ID: ${{ secrets.MIXPANEL_PROJECT_ID }}
423423
steps:
424424
- name: Free up disk space
425+
if: ${{ needs.setup.outputs.use_depot_cache != 'true' }}
425426
run: |
426427
sudo apt-get remove 'dotnet-*' azure-cli || true
427428
sudo rm -rf /usr/local/.ghcup || true

smoke-test/conftest.py

Lines changed: 107 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
import logging
12
import os
3+
import json
4+
from pathlib import Path
25

6+
from collections import defaultdict
37
import pytest
4-
from typing import List, Optional, Tuple
8+
from typing import Dict, List, Optional, Tuple
59
from _pytest.nodes import Item
610
import requests
711
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph, get_default_graph
@@ -18,6 +22,8 @@
1822
wait_for_writes_to_sync,
1923
)
2024

25+
logger = logging.getLogger(__name__)
26+
2127
# Disable telemetry
2228
os.environ["DATAHUB_TELEMETRY_ENABLED"] = "false"
2329
# Suppress logging manager to prevent I/O errors during pytest teardown
@@ -149,64 +155,121 @@ def bin_pack_tasks(tasks, n_buckets):
149155

150156
return buckets
151157

152-
def get_batch_start_end(num_tests: int) -> Tuple[int, int]:
153-
batch_count = env_vars.get_batch_count()
154158

155-
batch_number = env_vars.get_batch_number()
159+
def load_pytest_test_weights() -> Dict[str, float]:
    """
    Load pytest test weights from the JSON file next to this module.

    The file is expected to contain a list of objects with a ``testId`` key
    (``classname::test_name``) and a ``duration`` key. Durations are accepted
    either as strings with an optional trailing ``s`` (e.g. ``"262.807s"``)
    or as plain numbers.

    Returns:
        Dictionary mapping test IDs to durations in seconds.
        Returns empty dict if the weights file doesn't exist or cannot be
        read/parsed. Individual malformed entries are skipped so that one bad
        record does not discard every other weight.
    """
    weights_file = Path(__file__).parent / "pytest_test_weights.json"

    try:
        with open(weights_file, encoding="utf-8") as f:
            weights_data = json.load(f)
    except FileNotFoundError:
        # A missing weights file is a supported configuration: every test
        # then gets the caller's default weight.
        return {}
    except (OSError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        logger.warning("Warning: Failed to load pytest test weights: %s", e)
        return {}

    if not isinstance(weights_data, list):
        logger.warning(
            "Warning: Failed to load pytest test weights: expected a JSON list, got %s",
            type(weights_data).__name__,
        )
        return {}

    # Convert to dict: {"test_e2e::test_gms_get_dataset": 262.807, ...}
    weights: Dict[str, float] = {}
    for item in weights_data:
        try:
            test_id = item["testId"]
            duration = item["duration"]
            if isinstance(duration, str):
                duration = duration.strip()
                # Only strip the unit when it is actually present; slicing
                # blindly would turn "12.5" into "12.".
                if duration.endswith("s"):
                    duration = duration[:-1]
            weights[test_id] = float(duration)
        except (KeyError, TypeError, ValueError) as e:
            logger.warning("Skipping malformed pytest test weight entry %r: %s", item, e)

    return weights
184+
185+
186+
def _nodeid_to_weight_keys(nodeid: str) -> Tuple[str, str]:
    """Return the ``(preferred_key, legacy_key)`` weight-lookup keys for a pytest nodeid."""
    # Legacy conversion, kept so existing weight files continue to match:
    # tests/database/test_database.py::test_method -> tests.database.test_database::test_method
    legacy_key = nodeid.replace("/", ".").replace(".py::", "::")

    # Split the parametrization id off first so "/" or "::" inside it is left untouched.
    base, bracket, params = nodeid.partition("[")
    parts = base.split("::")
    if len(parts) < 2:
        return legacy_key, legacy_key

    module = parts[0].replace("/", ".")
    if module.endswith(".py"):
        module = module[:-3]
    # Enclosing classes become part of the dotted classname:
    # tests/a.py::TestC::test_m -> tests.a.TestC::test_m
    classname = ".".join([module] + parts[1:-1])
    return f"{classname}::{parts[-1]}{bracket}{params}", legacy_key


def aggregate_module_weights(items: List[Item], test_weights: Dict[str, float]) -> List[Tuple[str, List[Item], float]]:
    """
    Group test items by module and aggregate their weights.

    Args:
        items: List of pytest test items
        test_weights: Dictionary mapping test IDs ("classname::test_name") to durations

    Returns:
        List of (module_path, items_in_module, total_weight) tuples, in order of
        each module's first appearance in ``items``. Items within a module keep
        their original relative order. Tests with no known weight count as 1.0.
    """
    # Group items by module (file path)
    modules: Dict[str, List[Item]] = defaultdict(list)
    for item in items:
        modules[str(item.fspath)].append(item)

    # Calculate total weight for each module
    module_data = []
    for module_path, module_items in modules.items():
        total_weight = 0.0
        for item in module_items:
            preferred_key, legacy_key = _nodeid_to_weight_keys(item.nodeid)
            weight = test_weights.get(preferred_key)
            if weight is None:
                # Fall back to the legacy key format, then to the default weight.
                weight = test_weights.get(legacy_key, 1.0)
            total_weight += weight

        module_data.append((module_path, module_items, total_weight))

    return module_data
178225

179-
return batch_start, batch_end
180226

181227
def pytest_collection_modifyitems(
    session: pytest.Session, config: pytest.Config, items: List[Item]
) -> None:
    """
    Restrict the collected tests to the batch selected for this run.

    Tests are grouped by module, each module is weighted by the summed
    durations of its tests, and the modules are distributed over
    ``batch_count`` buckets with ``bin_pack_tasks``. Only the tests of the
    bucket at index ``batch_number`` are kept in ``items`` (modified in place).

    Does nothing for the cypress test strategy or when ``batch_count <= 1``.

    Raises:
        ValueError: if batching is active and the batch number is not in
            ``[0, batch_count)``.
    """
    if env_vars.get_test_strategy() == "cypress":
        return  # We launch cypress via pytest, but it needs a different batching mechanism at the cypress level.

    # Get batch configuration
    batch_count = int(env_vars.get_batch_count())
    batch_number = int(env_vars.get_batch_number())

    if batch_count <= 1:
        # No batching needed
        return

    # Without this check an oversized batch number surfaces as a bare
    # IndexError and a negative one silently selects a bucket from the end.
    if not 0 <= batch_number < batch_count:
        raise ValueError(
            f"Invalid batch number: {batch_number}, must be >= 0 and less than {batch_count} (zero based index)"
        )

    # Load test weights
    test_weights = load_pytest_test_weights()

    # Group items by module and aggregate weights
    module_data = aggregate_module_weights(items, test_weights)

    # Sort modules by path for stability
    module_data.sort(key=lambda x: x[0])

    # Create weighted tuples for bin-packing: (module_path, weight)
    # We'll also keep track of the items for each module
    module_map = {module_path: module_items for module_path, module_items, _ in module_data}
    weighted_modules = [(module_path, total_weight) for module_path, _, total_weight in module_data]

    logger.info(
        "Batching %s tests from %s modules across %s batches",
        len(items),
        len(weighted_modules),
        batch_count,
    )

    # Apply bin-packing to modules
    # NOTE(review): assumes each bucket returned by bin_pack_tasks is an iterable
    # of module paths, since its elements are used as module_map keys below — confirm.
    module_batches = bin_pack_tasks(weighted_modules, batch_count)

    # Get the modules for this batch
    selected_modules = module_batches[batch_number]

    # Flatten back to individual test items
    # Tests within each module maintain their original collection order
    selected_items = []
    for module_path in selected_modules:
        selected_items.extend(module_map[module_path])

    logger.info(
        "Batch %s: Running %s tests from %s modules",
        batch_number,
        len(selected_items),
        len(selected_modules),
    )
    if not selected_items:
        # Possible when there are more batches than modules.
        logger.warning("Batch %s of %s has no tests to run", batch_number, batch_count)

    # Replace items with the filtered list
    items[:] = selected_items

0 commit comments

Comments
 (0)