Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 80 additions & 5 deletions tests/remote/test_remote_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from diffusers.utils.remote_utils import remote_decode
from diffusers.utils.testing_utils import (
enable_full_determinism,
slow,
torch_all_close,
torch_device,
)
Expand All @@ -32,6 +33,11 @@

enable_full_determinism()

ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"


class RemoteAutoencoderKLMixin:
shape: Tuple[int, ...] = None
Expand Down Expand Up @@ -344,7 +350,7 @@ class RemoteAutoencoderKLSDv1Tests(
512,
512,
)
endpoint = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
endpoint = ENDPOINT_SD_V1
dtype = torch.float16
scaling_factor = 0.18215
shift_factor = None
Expand All @@ -368,7 +374,7 @@ class RemoteAutoencoderKLSDXLTests(
1024,
1024,
)
endpoint = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
endpoint = ENDPOINT_SD_XL
dtype = torch.float16
scaling_factor = 0.13025
shift_factor = None
Expand All @@ -392,7 +398,7 @@ class RemoteAutoencoderKLFluxTests(
1024,
1024,
)
endpoint = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
endpoint = ENDPOINT_FLUX
dtype = torch.bfloat16
scaling_factor = 0.3611
shift_factor = 0.1159
Expand All @@ -419,7 +425,7 @@ class RemoteAutoencoderKLFluxPackedTests(
)
height = 1024
width = 1024
endpoint = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
endpoint = ENDPOINT_FLUX
dtype = torch.bfloat16
scaling_factor = 0.3611
shift_factor = 0.1159
Expand Down Expand Up @@ -447,7 +453,7 @@ class RemoteAutoencoderKLHunyuanVideoTests(
320,
512,
)
endpoint = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
endpoint = ENDPOINT_HUNYUAN_VIDEO
dtype = torch.float16
scaling_factor = 0.476986
processor_cls = VideoProcessor
Expand All @@ -456,3 +462,72 @@ class RemoteAutoencoderKLHunyuanVideoTests(
[149, 161, 168, 136, 150, 156, 129, 143, 149], dtype=torch.uint8
)
return_pt_slice = torch.tensor([0.1656, 0.2661, 0.3157, 0.0693, 0.1755, 0.2252, 0.0127, 0.1221, 0.1708])


class RemoteAutoencoderKLSlowTestMixin:
channels: int = 4
endpoint: str = None
dtype: torch.dtype = None
scaling_factor: float = None
shift_factor: float = None
width: int = None
height: int = None

def get_dummy_inputs(self):
inputs = {
"endpoint": self.endpoint,
"scaling_factor": self.scaling_factor,
"shift_factor": self.shift_factor,
"height": self.height,
"width": self.width,
}
return inputs

def test_multi_res(self):
inputs = self.get_dummy_inputs()
for height in {320, 512, 640, 704, 896, 1024, 1208, 1384, 1536, 1608, 1864, 2048}:
for width in {320, 512, 640, 704, 896, 1024, 1208, 1384, 1536, 1608, 1864, 2048}:
inputs["tensor"] = torch.randn(
(1, self.channels, height // 8, width // 8),
device=torch_device,
dtype=self.dtype,
generator=torch.Generator(torch_device).manual_seed(13),
)
inputs["height"] = height
inputs["width"] = width
output = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
output.save(f"test_multi_res_{height}_{width}.png")


@slow
class RemoteAutoencoderKLSDv1SlowTests(
RemoteAutoencoderKLSlowTestMixin,
unittest.TestCase,
):
endpoint = ENDPOINT_SD_V1
dtype = torch.float16
scaling_factor = 0.18215
shift_factor = None


@slow
class RemoteAutoencoderKLSDXLSlowTests(
RemoteAutoencoderKLSlowTestMixin,
unittest.TestCase,
):
endpoint = ENDPOINT_SD_XL
dtype = torch.float16
scaling_factor = 0.13025
shift_factor = None


@slow
class RemoteAutoencoderKLFluxSlowTests(
RemoteAutoencoderKLSlowTestMixin,
unittest.TestCase,
):
channels = 16
endpoint = ENDPOINT_FLUX
dtype = torch.bfloat16
scaling_factor = 0.3611
shift_factor = 0.1159