diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index ebb3d7055319..a136d1b6bdff 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -994,10 +994,10 @@ def summary_failures_short(tr): config.option.tbstyle = orig_tbstyle -# Copied from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905 +# Adapted from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905 def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None): """ - To decorate flaky tests. They will be retried on failures. + To decorate flaky tests (methods or entire classes). They will be retried on failures. Args: max_attempts (`int`, *optional*, defaults to 5): @@ -1009,22 +1009,33 @@ def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, d etc.) """ - def decorator(test_func_ref): - @functools.wraps(test_func_ref) + def decorator(obj): + # If decorating a class, wrap each test method on it + if inspect.isclass(obj): + for attr_name, attr_value in list(obj.__dict__.items()): + if callable(attr_value) and attr_name.startswith("test"): + # recursively decorate the method + setattr(obj, attr_name, decorator(attr_value)) + return obj + + # Otherwise we're decorating a single test function / method + @functools.wraps(obj) def wrapper(*args, **kwargs): retry_count = 1 - while retry_count < max_attempts: try: - return test_func_ref(*args, **kwargs) - + return obj(*args, **kwargs) except Exception as err: - print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr) + msg = ( + f"[FLAKY] {description or obj.__name__!r} " + f"failed on attempt {retry_count}/{max_attempts}: {err}" + ) + print(msg, file=sys.stderr) if wait_before_retry is not None: time.sleep(wait_before_retry) retry_count += 1 - return test_func_ref(*args, **kwargs) + return obj(*args, **kwargs) return wrapper diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py index a7eb74080499..f976577653b2 100644 --- a/tests/lora/test_lora_layers_wanvace.py +++ b/tests/lora/test_lora_layers_wanvace.py @@ -46,6 +46,7 @@ @require_peft_backend @skip_mps +@is_flaky(max_attempts=10, description="very flaky class") class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): pipeline_class = WanVACEPipeline scheduler_cls = FlowMatchEulerDiscreteScheduler @@ -217,6 +218,5 @@ def test_lora_exclude_modules_wanvace(self): "Lora outputs should match.", ) - @is_flaky def test_simple_inference_with_text_denoiser_lora_and_scale(self): super().test_simple_inference_with_text_denoiser_lora_and_scale()