Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions src/diffusers/utils/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,10 +994,10 @@ def summary_failures_short(tr):
config.option.tbstyle = orig_tbstyle


# Copied from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905
# Adapted from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905
def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
"""
To decorate flaky tests. They will be retried on failures.
To decorate flaky tests (methods or entire classes). They will be retried on failures.

Args:
max_attempts (`int`, *optional*, defaults to 5):
Expand All @@ -1009,22 +1009,33 @@ def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, d
etc.)
"""

def decorator(test_func_ref):
@functools.wraps(test_func_ref)
def decorator(obj):
# If decorating a class, wrap each test method on it
if inspect.isclass(obj):
for attr_name, attr_value in list(obj.__dict__.items()):
if callable(attr_value) and attr_name.startswith("test"):
# recursively decorate the method
setattr(obj, attr_name, decorator(attr_value))
return obj

# Otherwise we're decorating a single test function / method
@functools.wraps(obj)
def wrapper(*args, **kwargs):
retry_count = 1

while retry_count < max_attempts:
try:
return test_func_ref(*args, **kwargs)

return obj(*args, **kwargs)
except Exception as err:
print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr)
msg = (
f"[FLAKY] {description or obj.__name__!r} "
f"failed on attempt {retry_count}/{max_attempts}: {err}"
)
print(msg, file=sys.stderr)
if wait_before_retry is not None:
time.sleep(wait_before_retry)
retry_count += 1

return test_func_ref(*args, **kwargs)
return obj(*args, **kwargs)

return wrapper

Expand Down
2 changes: 1 addition & 1 deletion tests/lora/test_lora_layers_wanvace.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@

@require_peft_backend
@skip_mps
@is_flaky(max_attempts=10, description="very flaky class")
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = WanVACEPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
Expand Down Expand Up @@ -217,6 +218,5 @@ def test_lora_exclude_modules_wanvace(self):
"Lora outputs should match.",
)

@is_flaky
def test_simple_inference_with_text_denoiser_lora_and_scale(self):
super().test_simple_inference_with_text_denoiser_lora_and_scale()