@@ -994,10 +994,10 @@ def summary_failures_short(tr):
994994 config .option .tbstyle = orig_tbstyle
995995
996996
997- # Copied from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905
997+ # Adapted from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905
998998def is_flaky (max_attempts : int = 5 , wait_before_retry : Optional [float ] = None , description : Optional [str ] = None ):
999999 """
1000- To decorate flaky tests. They will be retried on failures.
1000+ To decorate flaky tests (methods or entire classes) . They will be retried on failures.
10011001
10021002 Args:
10031003 max_attempts (`int`, *optional*, defaults to 5):
@@ -1009,22 +1009,33 @@ def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, d
10091009 etc.)
10101010 """
10111011
1012- def decorator (test_func_ref ):
1013- @functools .wraps (test_func_ref )
1012+ def decorator (obj ):
1013+ # If decorating a class, wrap each test method on it
1014+ if inspect .isclass (obj ):
1015+ for attr_name , attr_value in list (obj .__dict__ .items ()):
1016+ if callable (attr_value ) and attr_name .startswith ("test" ):
1017+ # recursively decorate the method
1018+ setattr (obj , attr_name , decorator (attr_value ))
1019+ return obj
1020+
1021+ # Otherwise we're decorating a single test function / method
1022+ @functools .wraps (obj )
10141023 def wrapper (* args , ** kwargs ):
10151024 retry_count = 1
1016-
10171025 while retry_count < max_attempts :
10181026 try :
1019- return test_func_ref (* args , ** kwargs )
1020-
1027+ return obj (* args , ** kwargs )
10211028 except Exception as err :
1022- print (f"Test failed with { err } at try { retry_count } /{ max_attempts } ." , file = sys .stderr )
1029+ msg = (
1030+ f"[FLAKY] { description or obj .__name__ !r} "
1031+ f"failed on attempt { retry_count } /{ max_attempts } : { err } "
1032+ )
1033+ print (msg , file = sys .stderr )
10231034 if wait_before_retry is not None :
10241035 time .sleep (wait_before_retry )
10251036 retry_count += 1
10261037
1027- return test_func_ref (* args , ** kwargs )
1038+ return obj (* args , ** kwargs )
10281039
10291040 return wrapper
10301041
0 commit comments