@@ -1377,3 +1377,88 @@ def find_expectation(self, key: DeviceProperties = (None, None)) -> Any:
13771377
13781378 def __repr__ (self ):
13791379 return f"{ self .data } "
1380+
1381+
1382+ def dynamic_slice_test (func ):
1383+ """
1384+ Decorator that injects an expected_slice parameter into a test function.
1385+
1386+ On the first run, it will capture the actual slice output and cache it.
1387+ On subsequent runs, it provides the cached slice as the expected slice.
1388+
1389+ Example:
1390+ ```python
1391+ @dynamic_slice_test
1392+ def test_stable_diffusion_ddim(self, expected_slice=None):
1393+ # Run the pipeline
1394+ components = self.get_dummy_components()
1395+ sd_pipe = StableDiffusionPipeline(**components)
1396+ inputs = self.get_dummy_inputs("cpu")
1397+ image = sd_pipe(**inputs).images
1398+ image_slice = image[0, -3:, -3:, -1]
1399+
1400+ # If expected_slice is provided (from cache), assert against it
1401+ if expected_slice is not None:
1402+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
1403+
1404+ # Always return the current slice for caching
1405+ return image_slice
1406+ ```
1407+ """
1408+ # Check if the function has the expected_slice parameter
1409+ sig = inspect .signature (func )
1410+ if "expected_slice" not in sig .parameters :
1411+ raise ValueError ("The decorated function must have an 'expected_slice' parameter" )
1412+
1413+ @functools .wraps (func )
1414+ def wrapper (* args , ** kwargs ):
1415+ # Get the test name from pytest
1416+ # pytest sets this environment variable to the current test
1417+ test_name = os .environ .get ("PYTEST_CURRENT_TEST" , "" )
1418+ if test_name :
1419+ # Format is: test_file.py::TestClass::test_method (call)
1420+ test_name = test_name .split (" " )[0 ]
1421+ else :
1422+ # Fallback if not running in pytest
1423+ test_name = f"{ func .__module__ } .{ func .__qualname__ } "
1424+
1425+ # Create a unique filename based on hardware details
1426+ device_props = get_device_properties ()
1427+ device_str = f"{ device_props [0 ]} { device_props [1 ] if device_props [1 ] is not None else '' } "
1428+
1429+ # Setup cache directory
1430+ cache_dir = os .environ .get ("DIFFUSERS_TEST_CACHE_DIR" , ".test_cache" )
1431+ os .makedirs (cache_dir , exist_ok = True )
1432+ cache_path = os .path .join (cache_dir , f"{ test_name } _{ device_str } .npy" )
1433+
1434+ # Check for cached expected slice
1435+ cached_slice = None
1436+ if os .path .exists (cache_path ):
1437+ try :
1438+ cached_slice = np .load (cache_path )
1439+ print (f"Using cached slice from { cache_path } " )
1440+ except Exception as e :
1441+ print (f"Error loading cached slice: { e } " )
1442+
1443+ # Run the test function with the expected slice injected
1444+ kwargs ["expected_slice" ] = cached_slice
1445+ actual_slice = func (* args , ** kwargs )
1446+
1447+ # If the function returned a slice and there's no cached slice yet, cache it
1448+ if actual_slice is not None and cached_slice is None :
1449+ # Convert torch tensor to numpy if needed
1450+ if hasattr (actual_slice , "detach" ) and hasattr (actual_slice , "cpu" ) and hasattr (actual_slice , "numpy" ):
1451+ actual_slice_np = actual_slice .detach ().cpu ().numpy ()
1452+ else :
1453+ actual_slice_np = actual_slice
1454+
1455+ # Save the slice
1456+ try :
1457+ np .save (cache_path , actual_slice_np )
1458+ print (f"Saved slice to cache: { cache_path } " )
1459+ except Exception as e :
1460+ print (f"Error saving slice to cache: { e } " )
1461+
1462+ return actual_slice
1463+
1464+ return wrapper
0 commit comments