Skip to content

Commit 891fd6a

Browse files
authored
Merge pull request #1217 from cloudbees-oss/support-split-option/impl
Support split subsetting
2 parents 60a888f + 02169a5 commit 891fd6a

File tree

7 files changed

+443
-7
lines changed

7 files changed

+443
-7
lines changed

smart_tests/commands/subset.py

Lines changed: 190 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@
2828
from ..utils.env_keys import REPORT_ERROR_KEY
2929
from ..utils.fail_fast_mode import (FailFastModeValidateParams, fail_fast_mode_validate,
3030
set_fail_fast_mode, warn_and_exit_if_fail_fast_mode)
31+
from ..utils.input_snapshot import InputSnapshotId
3132
from ..utils.smart_tests_client import SmartTestsClient
32-
from ..utils.typer_types import Duration, Percentage, parse_duration, parse_percentage
33+
from ..utils.typer_types import Duration, Fraction, Percentage, parse_duration, parse_fraction, parse_percentage
3334
from .test_path_writer import TestPathWriter
3435

3536

@@ -174,6 +175,23 @@ def __init__(
174175
type=fileText(mode="r"),
175176
metavar="FILE"
176177
)] = None,
178+
input_snapshot_id: Annotated[InputSnapshotId | None, InputSnapshotId.as_option()] = None,
179+
print_input_snapshot_id: Annotated[bool, typer.Option(
180+
"--print-input-snapshot-id",
181+
help="Print the input snapshot ID returned from the server instead of the subset results"
182+
)] = False,
183+
bin_target: Annotated[Fraction | None, typer.Option(
184+
"--bin",
185+
help="Split subset into bins, e.g. --bin 1/4",
186+
metavar="INDEX/COUNT",
187+
type=parse_fraction
188+
)] = None,
189+
same_bin_files: Annotated[List[str], typer.Option(
190+
"--same-bin",
191+
help="Keep all tests listed in the file together when splitting; one test per line",
192+
metavar="FILE",
193+
multiple=True
194+
)] = [],
177195
is_get_tests_from_guess: Annotated[bool, typer.Option(
178196
"--get-tests-from-guess",
179197
help="Get subset list from guessed tests"
@@ -255,9 +273,15 @@ def warn(msg: str):
255273
self.ignore_flaky_tests_above = ignore_flaky_tests_above
256274
self.prioritize_tests_failed_within_hours = prioritize_tests_failed_within_hours
257275
self.prioritized_tests_mapping_file = prioritized_tests_mapping_file
276+
self.input_snapshot_id = input_snapshot_id.value if input_snapshot_id else None
277+
self.print_input_snapshot_id = print_input_snapshot_id
278+
self.bin_target = bin_target
279+
self.same_bin_files = list(same_bin_files)
258280
self.is_get_tests_from_guess = is_get_tests_from_guess
259281
self.use_case = use_case
260282

283+
self._validate_print_input_snapshot_option()
284+
261285
self.file_path_normalizer = FilePathNormalizer(base_path, no_base_path_inference=no_base_path_inference)
262286

263287
self.test_paths: list[list[dict[str, str]]] = []
@@ -305,7 +329,7 @@ def stdin(self) -> Iterable[str]:
305329
"""
306330

307331
# To avoid the cli continue to wait from stdin
308-
if self.is_get_tests_from_previous_sessions or self.is_get_tests_from_guess:
332+
if self._should_skip_stdin():
309333
return []
310334

311335
if sys.stdin.isatty():
@@ -404,8 +428,103 @@ def get_payload(self) -> dict[str, Any]:
404428
if self.use_case:
405429
payload['changesUnderTest'] = self.use_case.value
406430

431+
if self.input_snapshot_id is not None:
432+
payload['subsettingId'] = self.input_snapshot_id
433+
434+
split_subset = self._build_split_subset_payload()
435+
if split_subset:
436+
payload['splitSubset'] = split_subset
437+
407438
return payload
408439

440+
def _build_split_subset_payload(self) -> dict[str, Any] | None:
441+
if self.bin_target is None:
442+
if self.same_bin_files:
443+
print_error_and_die(
444+
"--same-bin option requires --bin option.\nPlease set --bin option to use --same-bin",
445+
self.tracking_client,
446+
Tracking.ErrorEvent.USER_ERROR,
447+
)
448+
return None
449+
450+
slice_index = self.bin_target.numerator
451+
slice_count = self.bin_target.denominator
452+
453+
if slice_index <= 0 or slice_count <= 0:
454+
print_error_and_die(
455+
"Invalid --bin value. Both index and count must be positive integers.",
456+
self.tracking_client,
457+
Tracking.ErrorEvent.USER_ERROR,
458+
)
459+
460+
if slice_count < slice_index:
461+
print_error_and_die(
462+
"Invalid --bin value. The numerator cannot exceed the denominator.",
463+
self.tracking_client,
464+
Tracking.ErrorEvent.USER_ERROR,
465+
)
466+
467+
same_bins = self._read_same_bin_files()
468+
469+
return {
470+
"sliceIndex": slice_index,
471+
"sliceCount": slice_count,
472+
"sameBins": same_bins,
473+
}
474+
475+
def _read_same_bin_files(self) -> list[list[TestPath]]:
476+
if not self.same_bin_files:
477+
return []
478+
479+
formatter = self.same_bin_formatter
480+
if formatter is None:
481+
print_error_and_die(
482+
"--same-bin is not supported for this test runner.",
483+
self.tracking_client,
484+
Tracking.ErrorEvent.USER_ERROR,
485+
)
486+
487+
same_bins: list[list[TestPath]] = []
488+
seen_tests: set[str] = set()
489+
490+
for same_bin_file in self.same_bin_files:
491+
try:
492+
with open(same_bin_file, "r", encoding="utf-8") as fp:
493+
tests = [line.strip() for line in fp if line.strip()]
494+
except OSError as exc:
495+
print_error_and_die(
496+
f"Failed to read --same-bin file '{same_bin_file}': {exc}",
497+
self.tracking_client,
498+
Tracking.ErrorEvent.USER_ERROR,
499+
)
500+
501+
unique_tests = list(dict.fromkeys(tests))
502+
503+
group: list[TestPath] = []
504+
for test in unique_tests:
505+
if test in seen_tests:
506+
print_error_and_die(
507+
f"Error: test '{test}' is listed in multiple --same-bin files.",
508+
self.tracking_client,
509+
Tracking.ErrorEvent.USER_ERROR,
510+
)
511+
seen_tests.add(test)
512+
513+
# For type check
514+
assert formatter is not None, "--same -bin is not supported for this test runner"
515+
formatted = formatter(test)
516+
if not formatted:
517+
print_error_and_die(
518+
f"Failed to parse test '{test}' from --same-bin file {same_bin_file}",
519+
self.tracking_client,
520+
Tracking.ErrorEvent.USER_ERROR,
521+
)
522+
group.append(formatted)
523+
524+
same_bins.append(group)
525+
526+
return same_bins
527+
409528
def _collect_potential_test_files(self):
410529
LOOSE_TEST_FILE_PATTERN = r'(\.(test|spec)\.|_test\.|Test\.|Spec\.|test/|tests/|__tests__/|src/test/)'
411530
EXCLUDE_PATTERN = r'(BUILD|Makefile|Dockerfile|LICENSE|.gitignore|.gitkeep|.keep|id_rsa|rsa|blank|taglib)|\.(xml|json|jsonl|txt|yml|yaml|toml|md|png|jpg|jpeg|gif|svg|sql|html|css|graphql|proto|gz|zip|rz|bzl|conf|config|snap|pem|crt|key|lock|jpi|hpi|jelly|properties|jar|ini|mod|sum|bmp|env|envrc|sh)$' # noqa E501
@@ -463,13 +582,75 @@ def request_subset(self) -> SubsetResult:
463582
e, "Warning: the service failed to subset. Falling back to running all tests")
464583
return SubsetResult.from_test_paths(self.test_paths)
465584

585+
def _requires_test_input(self) -> bool:
586+
return (
587+
self.input_snapshot_id is None
588+
and not self.is_get_tests_from_previous_sessions # noqa: W503
589+
and len(self.test_paths) == 0 # noqa: W503
590+
)
591+
592+
def _should_skip_stdin(self) -> bool:
593+
if self.is_get_tests_from_previous_sessions or self.is_get_tests_from_guess:
594+
return True
595+
596+
if self.input_snapshot_id is not None:
597+
if not sys.stdin.isatty():
598+
warn_and_exit_if_fail_fast_mode(
599+
"Warning: --input-snapshot-id is set so stdin will be ignored."
600+
)
601+
return True
602+
return False
603+
604+
def _validate_print_input_snapshot_option(self):
605+
if not self.print_input_snapshot_id:
606+
return
607+
608+
conflicts: list[str] = []
609+
option_checks = [
610+
("--target", self.target is not None),
611+
("--time", self.time is not None),
612+
("--confidence", self.confidence is not None),
613+
("--goal-spec", self.goal_spec is not None),
614+
("--rest", self.rest is not None),
615+
("--bin", self.bin_target is not None),
616+
("--same-bin", bool(self.same_bin_files)),
617+
("--ignore-new-tests", self.ignore_new_tests),
618+
("--ignore-flaky-tests-above", self.ignore_flaky_tests_above is not None),
619+
("--prioritize-tests-failed-within-hours", self.prioritize_tests_failed_within_hours is not None),
620+
("--prioritized-tests-mapping", self.prioritized_tests_mapping_file is not None),
621+
("--get-tests-from-previous-sessions", self.is_get_tests_from_previous_sessions),
622+
("--get-tests-from-guess", self.is_get_tests_from_guess),
623+
("--output-exclusion-rules", self.is_output_exclusion_rules),
624+
("--non-blocking", self.is_non_blocking),
625+
]
626+
627+
for option_name, is_set in option_checks:
628+
if is_set:
629+
conflicts.append(option_name)
630+
631+
if conflicts:
632+
conflict_list = ", ".join(conflicts)
633+
print_error_and_die(
634+
f"--print-input-snapshot-id cannot be used with {conflict_list}.",
635+
self.tracking_client,
636+
Tracking.ErrorEvent.USER_ERROR,
637+
)
638+
639+
def _print_input_snapshot_id_value(self, subset_result: SubsetResult):
    """Echo the ID carried in ``subset_result.subset_id``.

    Raises:
        click.ClickException: when ``subset_result.subset_id`` is falsy.
    """
    if subset_result.subset_id:
        click.echo(subset_result.subset_id)
        return

    raise click.ClickException(
        "Subset request did not return an input snapshot ID. Please re-run the command."
    )
646+
466647
def run(self):
467648
"""called after tests are scanned to compute the optimized order"""
468649

469650
if self.is_get_tests_from_guess:
470651
self._collect_potential_test_files()
471652

472-
if not self.is_get_tests_from_previous_sessions and len(self.test_paths) == 0:
653+
if self._requires_test_input():
473654
if self.input_given:
474655
print_error_and_die("ERROR: Given arguments did not match any tests. They appear to be incorrect/non-existent.", tracking_client, Tracking.ErrorEvent.USER_ERROR) # noqa E501
475656
else:
@@ -488,6 +669,12 @@ def run(self):
488669

489670
if len(subset_result.subset) == 0:
490671
warn_and_exit_if_fail_fast_mode("Error: no tests found matching the path.")
672+
if self.print_input_snapshot_id:
673+
self._print_input_snapshot_id_value(subset_result)
674+
return
675+
676+
if self.print_input_snapshot_id:
677+
self._print_input_snapshot_id_value(subset_result)
491678
return
492679

493680
# TODO(Konboi): split subset isn't provided for smart-tests initial release

smart_tests/commands/test_path_writer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from os.path import join
2-
from typing import Callable, Dict, List
2+
from typing import Callable, List
33

44
import click
55

@@ -19,7 +19,7 @@ class TestPathWriter(object):
1919

2020
def __init__(self, app: Application):
2121
self.formatter = self.default_formatter
22-
self._same_bin_formatter: Callable[[str], Dict[str, str]] | None = None
22+
self._same_bin_formatter: Callable[[str], TestPath] | None = None
2323
self.separator = "\n"
2424
self.app = app
2525

@@ -43,9 +43,9 @@ def print(self, test_paths: List[TestPath]):
4343
for t in test_paths))
4444

4545
@property
46-
def same_bin_formatter(self) -> Callable[[str], Dict[str, str]] | None:
46+
def same_bin_formatter(self) -> Callable[[str], TestPath] | None:
4747
return self._same_bin_formatter
4848

4949
@same_bin_formatter.setter
50-
def same_bin_formatter(self, v: Callable[[str], Dict[str, str]]):
50+
def same_bin_formatter(self, v: Callable[[str], TestPath]):
5151
self._same_bin_formatter = v

smart_tests/test_runners/go_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def subset(client: Subset):
4141
test_cases = []
4242
client.formatter = lambda x: f"^{x[1]['name']}$"
4343
client.separator = '|'
44+
client.same_bin_formatter = format_same_bin
4445
client.run()
4546

4647

smart_tests/test_runners/gradle.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ def exclusion_output_handler(subset_tests, rest_tests):
7777
client.formatter = lambda x: f"--tests {x[0]['name']}"
7878
client.separator = ' '
7979

80+
client.same_bin_formatter = lambda s: [{"type": "class", "name": s}]
81+
8082
client.run()
8183

8284

smart_tests/test_runners/maven.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ def file2test(f: str) -> List | None:
126126
for root in source_roots:
127127
client.scan(root, '**/*', file2test)
128128

129+
client.same_bin_formatter = lambda s: [{"type": "class", "name": s}]
130+
129131
client.run()
130132

131133

smart_tests/utils/input_snapshot.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""Utility type for --input-snapshot-id option."""
2+
3+
import click
4+
5+
from smart_tests.args4p import typer
6+
7+
8+
class InputSnapshotId:
    """Value of the ``--input-snapshot-id`` option.

    Accepts either the ID itself or ``@path``; in the latter case the ID is
    read from the file at ``path`` (surrounding whitespace is stripped). The
    ID must parse as an integer greater than or equal to 1 and is stored in
    ``self.value``.
    """

    def __init__(self, raw: str):
        """Parse *raw* into a positive integer snapshot ID.

        Raises:
            click.BadParameter: when the ``@path`` file cannot be read, the
                text is not an integer, or the integer is smaller than 1.
        """
        value = str(raw)
        if value.startswith('@'):
            file_path = value[1:]
            try:
                with open(file_path, 'r', encoding='utf-8') as fp:
                    value = fp.read().strip()
            except OSError as exc:
                # Chain explicitly so the I/O failure is reported as the cause.
                raise click.BadParameter(
                    f"Failed to read input snapshot ID file '{file_path}': {exc}"
                ) from exc

        try:
            parsed = int(value)
        except ValueError as exc:
            raise click.BadParameter(
                f"Invalid input snapshot ID '{value}'. Expected a positive integer."
            ) from exc

        if parsed < 1:
            raise click.BadParameter(
                "Invalid input snapshot ID. Expected a positive integer."
            )

        self.value = parsed

    def __int__(self) -> int:
        return self.value

    def __str__(self) -> str:
        return str(self.value)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.value!r})"

    @staticmethod
    def as_option():
        """Return the ``typer.Option`` declaration for ``--input-snapshot-id``."""
        return typer.Option(
            "--input-snapshot-id",
            help="Reuse reorder results from an existing input snapshot ID or specify @path/to/file to load it",
            metavar="ID|@FILE",
            type=InputSnapshotId,
        )

0 commit comments

Comments
 (0)