diff --git a/tests/lora/test_lora_layers_sd.py b/tests/lora/test_lora_layers_sd.py index 3eefa97663e6..e50f5316da60 100644 --- a/tests/lora/test_lora_layers_sd.py +++ b/tests/lora/test_lora_layers_sd.py @@ -33,6 +33,7 @@ ) from diffusers.utils.import_utils import is_accelerate_available from diffusers.utils.testing_utils import ( + Expectations, backend_empty_cache, load_image, nightly, @@ -455,11 +456,54 @@ def test_vanilla_funetuning(self): images = pipe("A pokemon with blue eyes.", output_type="np", generator=generator, num_inference_steps=2).images - images = images[0, -3:, -3:, -1].flatten() - - expected = np.array([0.7406, 0.699, 0.5963, 0.7493, 0.7045, 0.6096, 0.6886, 0.6388, 0.583]) + image_slice = images[0, -3:, -3:, -1].flatten() + + expected_slices = Expectations( + { + ("xpu", 3): np.array( + [ + 0.6544, + 0.6127, + 0.5397, + 0.6845, + 0.6047, + 0.5469, + 0.6349, + 0.5906, + 0.5382, + ] + ), + ("cuda", 7): np.array( + [ + 0.7406, + 0.699, + 0.5963, + 0.7493, + 0.7045, + 0.6096, + 0.6886, + 0.6388, + 0.583, + ] + ), + ("cuda", 8): np.array( + [ + 0.6542, + 0.61253, + 0.5396, + 0.6843, + 0.6044, + 0.5468, + 0.6349, + 0.5905, + 0.5381, + ] + ), + } + ) + expected_slice = expected_slices.get_expectation() - max_diff = numpy_cosine_similarity_distance(expected, images) + max_diff = numpy_cosine_similarity_distance(expected_slice, image_slice) assert max_diff < 1e-4 pipe.unload_lora_weights() diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index be5245796b35..a950de142740 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1347,7 +1347,7 @@ def test_components_function(self): @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU") @require_accelerator - def test_float16_inference(self, expected_max_diff=6e-2): + def test_float16_inference(self, expected_max_diff=5e-2): components = self.get_dummy_components() pipe = self.pipeline_class(**components) for component in pipe.components.values():