diff --git a/tests/remote/test_remote_decode.py b/tests/remote/test_remote_decode.py index 4b8884607459..11f9c24d16f6 100644 --- a/tests/remote/test_remote_decode.py +++ b/tests/remote/test_remote_decode.py @@ -24,6 +24,7 @@ from diffusers.utils.remote_utils import remote_decode from diffusers.utils.testing_utils import ( enable_full_determinism, + slow, torch_all_close, torch_device, ) @@ -32,6 +33,11 @@ enable_full_determinism() +ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/" +ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/" +ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/" +ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/" + class RemoteAutoencoderKLMixin: shape: Tuple[int, ...] = None @@ -344,7 +350,7 @@ class RemoteAutoencoderKLSDv1Tests( 512, 512, ) - endpoint = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/" + endpoint = ENDPOINT_SD_V1 dtype = torch.float16 scaling_factor = 0.18215 shift_factor = None @@ -368,7 +374,7 @@ class RemoteAutoencoderKLSDXLTests( 1024, 1024, ) - endpoint = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/" + endpoint = ENDPOINT_SD_XL dtype = torch.float16 scaling_factor = 0.13025 shift_factor = None @@ -392,7 +398,7 @@ class RemoteAutoencoderKLFluxTests( 1024, 1024, ) - endpoint = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/" + endpoint = ENDPOINT_FLUX dtype = torch.bfloat16 scaling_factor = 0.3611 shift_factor = 0.1159 @@ -419,7 +425,7 @@ class RemoteAutoencoderKLFluxPackedTests( ) height = 1024 width = 1024 - endpoint = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/" + endpoint = ENDPOINT_FLUX dtype = torch.bfloat16 scaling_factor = 0.3611 shift_factor = 0.1159 @@ -447,7 +453,7 @@ class RemoteAutoencoderKLHunyuanVideoTests( 320, 512, ) - endpoint = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/" + endpoint = ENDPOINT_HUNYUAN_VIDEO dtype = torch.float16 scaling_factor = 0.476986 processor_cls = VideoProcessor @@ -456,3 +462,72 @@ class RemoteAutoencoderKLHunyuanVideoTests( [149, 161, 168, 136, 150, 156, 129, 143, 149], dtype=torch.uint8 ) return_pt_slice = torch.tensor([0.1656, 0.2661, 0.3157, 0.0693, 0.1755, 0.2252, 0.0127, 0.1221, 0.1708]) + + +class RemoteAutoencoderKLSlowTestMixin: + channels: int = 4 + endpoint: str = None + dtype: torch.dtype = None + scaling_factor: float = None + shift_factor: float = None + width: int = None + height: int = None + + def get_dummy_inputs(self): + inputs = { + "endpoint": self.endpoint, + "scaling_factor": self.scaling_factor, + "shift_factor": self.shift_factor, + "height": self.height, + "width": self.width, + } + return inputs + + def test_multi_res(self): + inputs = self.get_dummy_inputs() + for height in {320, 512, 640, 704, 896, 1024, 1208, 1384, 1536, 1608, 1864, 2048}: + for width in {320, 512, 640, 704, 896, 1024, 1208, 1384, 1536, 1608, 1864, 2048}: + inputs["tensor"] = torch.randn( + (1, self.channels, height // 8, width // 8), + device=torch_device, + dtype=self.dtype, + generator=torch.Generator(torch_device).manual_seed(13), + ) + inputs["height"] = height + inputs["width"] = width + output = remote_decode(output_type="pt", partial_postprocess=True, **inputs) + output.save(f"test_multi_res_{height}_{width}.png") + + +@slow +class RemoteAutoencoderKLSDv1SlowTests( + RemoteAutoencoderKLSlowTestMixin, + unittest.TestCase, +): + endpoint = ENDPOINT_SD_V1 + dtype = torch.float16 + scaling_factor = 0.18215 + shift_factor = None + + +@slow +class RemoteAutoencoderKLSDXLSlowTests( + RemoteAutoencoderKLSlowTestMixin, + unittest.TestCase, +): + endpoint = ENDPOINT_SD_XL + dtype = torch.float16 + scaling_factor = 0.13025 + shift_factor = None + + +@slow +class RemoteAutoencoderKLFluxSlowTests( + RemoteAutoencoderKLSlowTestMixin, + unittest.TestCase, +): + channels = 16 + endpoint = ENDPOINT_FLUX + dtype = torch.bfloat16 + scaling_factor = 0.3611 + shift_factor = 0.1159