Skip to content

Commit 7e13c18

Browse files
committed
update
1 parent 71f34fc commit 7e13c18

File tree

1 file changed

+85
-0
lines changed

1 file changed

+85
-0
lines changed

src/diffusers/utils/testing_utils.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,3 +1377,88 @@ def find_expectation(self, key: DeviceProperties = (None, None)) -> Any:
13771377

13781378
def __repr__(self):
13791379
return f"{self.data}"
1380+
1381+
1382+
def dynamic_slice_test(func):
1383+
"""
1384+
Decorator that injects an expected_slice parameter into a test function.
1385+
1386+
On the first run, it will capture the actual slice output and cache it.
1387+
On subsequent runs, it provides the cached slice as the expected slice.
1388+
1389+
Example:
1390+
```python
1391+
@dynamic_slice_test
1392+
def test_stable_diffusion_ddim(self, expected_slice=None):
1393+
# Run the pipeline
1394+
components = self.get_dummy_components()
1395+
sd_pipe = StableDiffusionPipeline(**components)
1396+
inputs = self.get_dummy_inputs("cpu")
1397+
image = sd_pipe(**inputs).images
1398+
image_slice = image[0, -3:, -3:, -1]
1399+
1400+
# If expected_slice is provided (from cache), assert against it
1401+
if expected_slice is not None:
1402+
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
1403+
1404+
# Always return the current slice for caching
1405+
return image_slice
1406+
```
1407+
"""
1408+
# Check if the function has the expected_slice parameter
1409+
sig = inspect.signature(func)
1410+
if "expected_slice" not in sig.parameters:
1411+
raise ValueError("The decorated function must have an 'expected_slice' parameter")
1412+
1413+
@functools.wraps(func)
1414+
def wrapper(*args, **kwargs):
1415+
# Get the test name from pytest
1416+
# pytest sets this environment variable to the current test
1417+
test_name = os.environ.get("PYTEST_CURRENT_TEST", "")
1418+
if test_name:
1419+
# Format is: test_file.py::TestClass::test_method (call)
1420+
test_name = test_name.split(" ")[0]
1421+
else:
1422+
# Fallback if not running in pytest
1423+
test_name = f"{func.__module__}.{func.__qualname__}"
1424+
1425+
# Create a unique filename based on hardware details
1426+
device_props = get_device_properties()
1427+
device_str = f"{device_props[0]}{device_props[1] if device_props[1] is not None else ''}"
1428+
1429+
# Setup cache directory
1430+
cache_dir = os.environ.get("DIFFUSERS_TEST_CACHE_DIR", ".test_cache")
1431+
os.makedirs(cache_dir, exist_ok=True)
1432+
cache_path = os.path.join(cache_dir, f"{test_name}_{device_str}.npy")
1433+
1434+
# Check for cached expected slice
1435+
cached_slice = None
1436+
if os.path.exists(cache_path):
1437+
try:
1438+
cached_slice = np.load(cache_path)
1439+
print(f"Using cached slice from {cache_path}")
1440+
except Exception as e:
1441+
print(f"Error loading cached slice: {e}")
1442+
1443+
# Run the test function with the expected slice injected
1444+
kwargs["expected_slice"] = cached_slice
1445+
actual_slice = func(*args, **kwargs)
1446+
1447+
# If the function returned a slice and there's no cached slice yet, cache it
1448+
if actual_slice is not None and cached_slice is None:
1449+
# Convert torch tensor to numpy if needed
1450+
if hasattr(actual_slice, "detach") and hasattr(actual_slice, "cpu") and hasattr(actual_slice, "numpy"):
1451+
actual_slice_np = actual_slice.detach().cpu().numpy()
1452+
else:
1453+
actual_slice_np = actual_slice
1454+
1455+
# Save the slice
1456+
try:
1457+
np.save(cache_path, actual_slice_np)
1458+
print(f"Saved slice to cache: {cache_path}")
1459+
except Exception as e:
1460+
print(f"Error saving slice to cache: {e}")
1461+
1462+
return actual_slice
1463+
1464+
return wrapper

0 commit comments

Comments
 (0)