Skip to content

Commit a23ffaa

Browse files
authored
pytest test sharding (#493)
Adds test sharding support to the pytest runner. I took some inspiration from https://github.com/caseyduquettesc/rules_python_pytest . In particular, I used it as a reference for which Bazel environment variables need to be pulled in the shim. The pytest plugin is a derivative work of https://github.com/AdamGleave/pytest-shard . That dependency hasn't been updated in a long time, and the sharding was based off hashing which often results in unbalanced shards or errors from empty shards. I vendored the code into the repo, modified it to use a plain round robin strategy for selection, and converted the code to be class based so it can be used with `pytest.main()`. ### Test plan I added an example to the examples directory and I tested it with my own repository here: https://github.com/vinnybod/bazel-examples/blob/python-test-sharding/test-shard-python/BUILD.bazel
1 parent 6ea32ea commit a23ffaa

File tree

10 files changed

+132
-2
lines changed

10 files changed

+132
-2
lines changed

examples/pytest/BUILD.bazel

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,18 @@ py_test(
5858
"@pypi_pytest//:pkg",
5959
],
6060
)
61+
62+
py_test(
63+
name = "sharding_test",
64+
srcs = [
65+
"__test__.py",
66+
"sharding_test.py",
67+
],
68+
imports = ["../.."],
69+
main = "__test__.py",
70+
package_collisions = "warning",
71+
shard_count = 2,
72+
deps = [
73+
"__test__",
74+
],
75+
)

examples/pytest/sharding_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
def test_shard_one():
2+
assert True
3+
4+
def test_shard_two():
5+
assert True

examples/virtual_deps/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ py_test(
6969
),
7070
deps = [
7171
requirement("pytest"),
72+
"__test__",
7273
":greet",
7374
],
7475
)

py/private/py_pytest_main.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,6 @@ def py_pytest_main(name, py_library = default_py_library, deps = [], data = [],
8888
srcs = [test_main],
8989
tags = tags,
9090
visibility = visibility,
91-
deps = deps,
91+
deps = deps + [Label("//py/private/pytest_shard")],
9292
data = data,
9393
)

py/private/pytest.py.tmpl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import sys
1616
import os
17+
from pathlib import Path
1718
from typing import List
1819

1920
import pytest
@@ -33,12 +34,15 @@ if "COVERAGE_MANIFEST" in os.environ:
3334
print("WARNING: python coverage setup failed. Do you need to include the 'coverage' library as a dependency of py_pytest_main?", e)
3435
pass
3536

37+
from pytest_shard import ShardPlugin
38+
3639
if __name__ == "__main__":
3740
# Change to the directory where we need to run the test or execute a no-op
3841
$$CHDIR$$
3942

4043
os.environ["ENV"] = "testing"
4144

45+
plugins = []
4246
args = [
4347
"--verbose",
4448
"--ignore=external/",
@@ -55,6 +59,20 @@ if __name__ == "__main__":
5559
if suite_name:
5660
args.extend(["-o", f"junit_suite_name={suite_name}"])
5761

62+
test_shard_index = os.environ.get("TEST_SHARD_INDEX")
63+
test_total_shards = os.environ.get("TEST_TOTAL_SHARDS")
64+
test_shard_status_file = os.environ.get("TEST_SHARD_STATUS_FILE")
65+
if (
66+
all([test_shard_index, test_total_shards, test_shard_status_file])
67+
and int(test_total_shards) > 1
68+
):
69+
args.extend([
70+
f"--shard-id={test_shard_index}",
71+
f"--num-shards={test_total_shards}",
72+
])
73+
Path(test_shard_status_file).touch()
74+
plugins.append(ShardPlugin())
75+
5876
test_filter = os.environ.get("TESTBRIDGE_TEST_ONLY")
5977
if test_filter is not None:
6078
args.append(f"-k={test_filter}")
@@ -67,7 +85,7 @@ if __name__ == "__main__":
6785
if len(cli_args) > 0:
6886
args.extend(cli_args)
6987

70-
exit_code = pytest.main(args)
88+
exit_code = pytest.main(args, plugins=plugins)
7189

7290
if exit_code != 0:
7391
print("Pytest exit code: " + str(exit_code), file=sys.stderr)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
load("@aspect_rules_py//py:defs.bzl", "py_library")
2+
3+
py_library(
4+
name = "pytest_shard",
5+
srcs = [
6+
"__init__.py",
7+
":pytest_shard.py",
8+
],
9+
imports = ["."],
10+
visibility = ["//visibility:public"],
11+
)

py/private/pytest_shard/LICENSE

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Copyright 2019 Adam Gleave
2+
3+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4+
5+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6+
7+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

py/private/pytest_shard/README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# pytest-shard
2+
3+
This is a fork of [pytest-shard](https://github.com/AdamGleave/pytest-shard) @ [4610a08](https://github.com/AdamGleave/pytest-shard/commit/64610a08dac6b0511b6d51cf895d0e1040d162ad)
4+
5+
## Changes
6+
7+
- The pytest hooks were moved into a class `ShardPlugin`, so that they can be loaded via `pytest.main`
8+
- The sharding strategy was changed to a simple round-robin strategy
9+
- The hash-bashed strategy was causing unbalanced or empty shards with small test sets

py/private/pytest_shard/__init__.py

Whitespace-only changes.
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from typing import Iterable, List, Sequence
2+
3+
from _pytest import nodes # for type checking only
4+
5+
6+
def positive_int(x) -> int:
7+
x = int(x)
8+
if x < 0:
9+
raise ValueError(f"Argument {x} must be positive")
10+
return x
11+
12+
13+
def filter_items_by_shard(
14+
items: Iterable[nodes.Node], shard_id: int, num_shards: int
15+
) -> Sequence[nodes.Node]:
16+
"""Computes `items` that should be tested in `shard_id` out of `num_shards` total shards."""
17+
shards = [i % num_shards for i in range(len(items))]
18+
19+
new_items = []
20+
for shard, item in zip(shards, items):
21+
if shard == shard_id:
22+
new_items.append(item)
23+
return new_items
24+
25+
26+
class ShardPlugin:
27+
@staticmethod
28+
def pytest_addoption(parser):
29+
"""Add pytest-shard specific configuration parameters."""
30+
group = parser.getgroup("shard")
31+
group.addoption(
32+
"--shard-id",
33+
dest="shard_id",
34+
type=positive_int,
35+
default=0,
36+
help="Number of this shard.",
37+
)
38+
group.addoption(
39+
"--num-shards",
40+
dest="num_shards",
41+
type=positive_int,
42+
default=1,
43+
help="Total number of shards.",
44+
)
45+
46+
@staticmethod
47+
def pytest_report_collectionfinish(config, items: Sequence[nodes.Node]) -> str:
48+
"""Log how many and, if verbose, which items are tested in this shard."""
49+
msg = f"Running {len(items)} items in this shard"
50+
if config.option.verbose > 0 and config.getoption("num_shards") > 1:
51+
msg += ": " + ", ".join([item.nodeid for item in items])
52+
return msg
53+
54+
@staticmethod
55+
def pytest_collection_modifyitems(config, items: List[nodes.Node]):
56+
"""Mutate the collection to consist of just items to be tested in this shard."""
57+
shard_id = config.getoption("shard_id")
58+
shard_total = config.getoption("num_shards")
59+
if shard_id >= shard_total:
60+
raise ValueError(
61+
"shard_num = f{shard_num} must be less than shard_total = f{shard_total}"
62+
)
63+
64+
items[:] = filter_items_by_shard(items, shard_id, shard_total)

0 commit comments

Comments
 (0)