2828from ..utils .env_keys import REPORT_ERROR_KEY
2929from ..utils .fail_fast_mode import (FailFastModeValidateParams , fail_fast_mode_validate ,
3030 set_fail_fast_mode , warn_and_exit_if_fail_fast_mode )
31+ from ..utils .input_snapshot import InputSnapshotId
3132from ..utils .smart_tests_client import SmartTestsClient
32- from ..utils .typer_types import Duration , Percentage , parse_duration , parse_percentage
33+ from ..utils .typer_types import Duration , Fraction , Percentage , parse_duration , parse_fraction , parse_percentage
3334from .test_path_writer import TestPathWriter
3435
3536
@@ -174,6 +175,23 @@ def __init__(
174175 type = fileText (mode = "r" ),
175176 metavar = "FILE"
176177 )] = None ,
178+ input_snapshot_id : Annotated [InputSnapshotId | None , InputSnapshotId .as_option ()] = None ,
179+ print_input_snapshot_id : Annotated [bool , typer .Option (
180+ "--print-input-snapshot-id" ,
181+ help = "Print the input snapshot ID returned from the server instead of the subset results"
182+ )] = False ,
183+ bin_target : Annotated [Fraction | None , typer .Option (
184+ "--bin" ,
185+ help = "Split subset into bins, e.g. --bin 1/4" ,
186+ metavar = "INDEX/COUNT" ,
187+ type = parse_fraction
188+ )] = None ,
189+ same_bin_files : Annotated [List [str ], typer .Option (
190+ "--same-bin" ,
191+ help = "Keep all tests listed in the file together when splitting; one test per line" ,
192+ metavar = "FILE" ,
193+ multiple = True
194+ )] = [],
177195 is_get_tests_from_guess : Annotated [bool , typer .Option (
178196 "--get-tests-from-guess" ,
179197 help = "Get subset list from guessed tests"
@@ -255,9 +273,15 @@ def warn(msg: str):
255273 self .ignore_flaky_tests_above = ignore_flaky_tests_above
256274 self .prioritize_tests_failed_within_hours = prioritize_tests_failed_within_hours
257275 self .prioritized_tests_mapping_file = prioritized_tests_mapping_file
276+ self .input_snapshot_id = input_snapshot_id .value if input_snapshot_id else None
277+ self .print_input_snapshot_id = print_input_snapshot_id
278+ self .bin_target = bin_target
279+ self .same_bin_files = list (same_bin_files )
258280 self .is_get_tests_from_guess = is_get_tests_from_guess
259281 self .use_case = use_case
260282
283+ self ._validate_print_input_snapshot_option ()
284+
261285 self .file_path_normalizer = FilePathNormalizer (base_path , no_base_path_inference = no_base_path_inference )
262286
263287 self .test_paths : list [list [dict [str , str ]]] = []
@@ -305,7 +329,7 @@ def stdin(self) -> Iterable[str]:
305329 """
306330
307331 # To avoid the cli continue to wait from stdin
308- if self .is_get_tests_from_previous_sessions or self . is_get_tests_from_guess :
332+ if self ._should_skip_stdin () :
309333 return []
310334
311335 if sys .stdin .isatty ():
@@ -404,8 +428,103 @@ def get_payload(self) -> dict[str, Any]:
404428 if self .use_case :
405429 payload ['changesUnderTest' ] = self .use_case .value
406430
431+ if self .input_snapshot_id is not None :
432+ payload ['subsettingId' ] = self .input_snapshot_id
433+
434+ split_subset = self ._build_split_subset_payload ()
435+ if split_subset :
436+ payload ['splitSubset' ] = split_subset
437+
407438 return payload
408439
440+ def _build_split_subset_payload (self ) -> dict [str , Any ] | None :
441+ if self .bin_target is None :
442+ if self .same_bin_files :
443+ print_error_and_die (
444+ "--same-bin option requires --bin option.\n Please set --bin option to use --same-bin" ,
445+ self .tracking_client ,
446+ Tracking .ErrorEvent .USER_ERROR ,
447+ )
448+ return None
449+
450+ slice_index = self .bin_target .numerator
451+ slice_count = self .bin_target .denominator
452+
453+ if slice_index <= 0 or slice_count <= 0 :
454+ print_error_and_die (
455+ "Invalid --bin value. Both index and count must be positive integers." ,
456+ self .tracking_client ,
457+ Tracking .ErrorEvent .USER_ERROR ,
458+ )
459+
460+ if slice_count < slice_index :
461+ print_error_and_die (
462+ "Invalid --bin value. The numerator cannot exceed the denominator." ,
463+ self .tracking_client ,
464+ Tracking .ErrorEvent .USER_ERROR ,
465+ )
466+
467+ same_bins = self ._read_same_bin_files ()
468+
469+ return {
470+ "sliceIndex" : slice_index ,
471+ "sliceCount" : slice_count ,
472+ "sameBins" : same_bins ,
473+ }
474+
475+ def _read_same_bin_files (self ) -> list [list [TestPath ]]:
476+ if not self .same_bin_files :
477+ return []
478+
479+ formatter = self .same_bin_formatter
480+ if formatter is None :
481+ print_error_and_die (
482+ "--same-bin is not supported for this test runner." ,
483+ self .tracking_client ,
484+ Tracking .ErrorEvent .USER_ERROR ,
485+ )
486+
487+ same_bins : list [list [TestPath ]] = []
488+ seen_tests : set [str ] = set ()
489+
490+ for same_bin_file in self .same_bin_files :
491+ try :
492+ with open (same_bin_file , "r" , encoding = "utf-8" ) as fp :
493+ tests = [line .strip () for line in fp if line .strip ()]
494+ except OSError as exc :
495+ print_error_and_die (
496+ f"Failed to read --same-bin file '{ same_bin_file } ': { exc } " ,
497+ self .tracking_client ,
498+ Tracking .ErrorEvent .USER_ERROR ,
499+ )
500+
501+ unique_tests = list (dict .fromkeys (tests ))
502+
503+ group : list [TestPath ] = []
504+ for test in unique_tests :
505+ if test in seen_tests :
506+ print_error_and_die (
507+ f"Error: test '{ test } ' is listed in multiple --same-bin files." ,
508+ self .tracking_client ,
509+ Tracking .ErrorEvent .USER_ERROR ,
510+ )
511+ seen_tests .add (test )
512+
513+ # For type check
514+ assert formatter is not None , "--same -bin is not supported for this test runner"
515+ formatted = formatter (test )
516+ if not formatted :
517+ print_error_and_die (
518+ f"Failed to parse test '{ test } ' from --same-bin file { same_bin_file } " ,
519+ self .tracking_client ,
520+ Tracking .ErrorEvent .USER_ERROR ,
521+ )
522+ group .append (formatted )
523+
524+ same_bins .append (group )
525+
526+ return same_bins
527+
409528 def _collect_potential_test_files (self ):
410529 LOOSE_TEST_FILE_PATTERN = r'(\.(test|spec)\.|_test\.|Test\.|Spec\.|test/|tests/|__tests__/|src/test/)'
411530 EXCLUDE_PATTERN = r'(BUILD|Makefile|Dockerfile|LICENSE|.gitignore|.gitkeep|.keep|id_rsa|rsa|blank|taglib)|\.(xml|json|jsonl|txt|yml|yaml|toml|md|png|jpg|jpeg|gif|svg|sql|html|css|graphql|proto|gz|zip|rz|bzl|conf|config|snap|pem|crt|key|lock|jpi|hpi|jelly|properties|jar|ini|mod|sum|bmp|env|envrc|sh)$' # noqa E501
@@ -463,13 +582,75 @@ def request_subset(self) -> SubsetResult:
463582 e , "Warning: the service failed to subset. Falling back to running all tests" )
464583 return SubsetResult .from_test_paths (self .test_paths )
465584
585+ def _requires_test_input (self ) -> bool :
586+ return (
587+ self .input_snapshot_id is None
588+ and not self .is_get_tests_from_previous_sessions # noqa: W503
589+ and len (self .test_paths ) == 0 # noqa: W503
590+ )
591+
592+ def _should_skip_stdin (self ) -> bool :
593+ if self .is_get_tests_from_previous_sessions or self .is_get_tests_from_guess :
594+ return True
595+
596+ if self .input_snapshot_id is not None :
597+ if not sys .stdin .isatty ():
598+ warn_and_exit_if_fail_fast_mode (
599+ "Warning: --input-snapshot-id is set so stdin will be ignored."
600+ )
601+ return True
602+ return False
603+
604+ def _validate_print_input_snapshot_option (self ):
605+ if not self .print_input_snapshot_id :
606+ return
607+
608+ conflicts : list [str ] = []
609+ option_checks = [
610+ ("--target" , self .target is not None ),
611+ ("--time" , self .time is not None ),
612+ ("--confidence" , self .confidence is not None ),
613+ ("--goal-spec" , self .goal_spec is not None ),
614+ ("--rest" , self .rest is not None ),
615+ ("--bin" , self .bin_target is not None ),
616+ ("--same-bin" , bool (self .same_bin_files )),
617+ ("--ignore-new-tests" , self .ignore_new_tests ),
618+ ("--ignore-flaky-tests-above" , self .ignore_flaky_tests_above is not None ),
619+ ("--prioritize-tests-failed-within-hours" , self .prioritize_tests_failed_within_hours is not None ),
620+ ("--prioritized-tests-mapping" , self .prioritized_tests_mapping_file is not None ),
621+ ("--get-tests-from-previous-sessions" , self .is_get_tests_from_previous_sessions ),
622+ ("--get-tests-from-guess" , self .is_get_tests_from_guess ),
623+ ("--output-exclusion-rules" , self .is_output_exclusion_rules ),
624+ ("--non-blocking" , self .is_non_blocking ),
625+ ]
626+
627+ for option_name , is_set in option_checks :
628+ if is_set :
629+ conflicts .append (option_name )
630+
631+ if conflicts :
632+ conflict_list = ", " .join (conflicts )
633+ print_error_and_die (
634+ f"--print-input-snapshot-id cannot be used with { conflict_list } ." ,
635+ self .tracking_client ,
636+ Tracking .ErrorEvent .USER_ERROR ,
637+ )
638+
639+ def _print_input_snapshot_id_value (self , subset_result : SubsetResult ):
640+ if not subset_result .subset_id :
641+ raise click .ClickException (
642+ "Subset request did not return an input snapshot ID. Please re-run the command."
643+ )
644+
645+ click .echo (subset_result .subset_id )
646+
466647 def run (self ):
467648 """called after tests are scanned to compute the optimized order"""
468649
469650 if self .is_get_tests_from_guess :
470651 self ._collect_potential_test_files ()
471652
472- if not self .is_get_tests_from_previous_sessions and len ( self . test_paths ) == 0 :
653+ if self ._requires_test_input () :
473654 if self .input_given :
474655 print_error_and_die ("ERROR: Given arguments did not match any tests. They appear to be incorrect/non-existent." , tracking_client , Tracking .ErrorEvent .USER_ERROR ) # noqa E501
475656 else :
@@ -488,6 +669,12 @@ def run(self):
488669
489670 if len (subset_result .subset ) == 0 :
490671 warn_and_exit_if_fail_fast_mode ("Error: no tests found matching the path." )
672+ if self .print_input_snapshot_id :
673+ self ._print_input_snapshot_id_value (subset_result )
674+ return
675+
676+ if self .print_input_snapshot_id :
677+ self ._print_input_snapshot_id_value (subset_result )
491678 return
492679
493680 # TODO(Konboi): split subset isn't provided for smart-tests initial release
0 commit comments