Skip to content

Commit 3cced85

Browse files
authored
Merge branch 'main' into bug-fix
2 parents f91cbd1 + 08f74a8 commit 3cced85

File tree

1 file changed

+80
-5
lines changed

1 file changed

+80
-5
lines changed

tests/remote/test_remote_decode.py

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from diffusers.utils.remote_utils import remote_decode
2525
from diffusers.utils.testing_utils import (
2626
enable_full_determinism,
27+
slow,
2728
torch_all_close,
2829
torch_device,
2930
)
@@ -32,6 +33,11 @@
3233

3334
enable_full_determinism()
3435

36+
ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
37+
ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
38+
ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
39+
ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
40+
3541

3642
class RemoteAutoencoderKLMixin:
3743
shape: Tuple[int, ...] = None
@@ -344,7 +350,7 @@ class RemoteAutoencoderKLSDv1Tests(
344350
512,
345351
512,
346352
)
347-
endpoint = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
353+
endpoint = ENDPOINT_SD_V1
348354
dtype = torch.float16
349355
scaling_factor = 0.18215
350356
shift_factor = None
@@ -368,7 +374,7 @@ class RemoteAutoencoderKLSDXLTests(
368374
1024,
369375
1024,
370376
)
371-
endpoint = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
377+
endpoint = ENDPOINT_SD_XL
372378
dtype = torch.float16
373379
scaling_factor = 0.13025
374380
shift_factor = None
@@ -392,7 +398,7 @@ class RemoteAutoencoderKLFluxTests(
392398
1024,
393399
1024,
394400
)
395-
endpoint = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
401+
endpoint = ENDPOINT_FLUX
396402
dtype = torch.bfloat16
397403
scaling_factor = 0.3611
398404
shift_factor = 0.1159
@@ -419,7 +425,7 @@ class RemoteAutoencoderKLFluxPackedTests(
419425
)
420426
height = 1024
421427
width = 1024
422-
endpoint = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
428+
endpoint = ENDPOINT_FLUX
423429
dtype = torch.bfloat16
424430
scaling_factor = 0.3611
425431
shift_factor = 0.1159
@@ -447,7 +453,7 @@ class RemoteAutoencoderKLHunyuanVideoTests(
447453
320,
448454
512,
449455
)
450-
endpoint = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
456+
endpoint = ENDPOINT_HUNYUAN_VIDEO
451457
dtype = torch.float16
452458
scaling_factor = 0.476986
453459
processor_cls = VideoProcessor
@@ -456,3 +462,72 @@ class RemoteAutoencoderKLHunyuanVideoTests(
456462
[149, 161, 168, 136, 150, 156, 129, 143, 149], dtype=torch.uint8
457463
)
458464
return_pt_slice = torch.tensor([0.1656, 0.2661, 0.3157, 0.0693, 0.1755, 0.2252, 0.0127, 0.1221, 0.1708])
465+
466+
467+
class RemoteAutoencoderKLSlowTestMixin:
468+
channels: int = 4
469+
endpoint: str = None
470+
dtype: torch.dtype = None
471+
scaling_factor: float = None
472+
shift_factor: float = None
473+
width: int = None
474+
height: int = None
475+
476+
def get_dummy_inputs(self):
477+
inputs = {
478+
"endpoint": self.endpoint,
479+
"scaling_factor": self.scaling_factor,
480+
"shift_factor": self.shift_factor,
481+
"height": self.height,
482+
"width": self.width,
483+
}
484+
return inputs
485+
486+
def test_multi_res(self):
487+
inputs = self.get_dummy_inputs()
488+
for height in {320, 512, 640, 704, 896, 1024, 1208, 1384, 1536, 1608, 1864, 2048}:
489+
for width in {320, 512, 640, 704, 896, 1024, 1208, 1384, 1536, 1608, 1864, 2048}:
490+
inputs["tensor"] = torch.randn(
491+
(1, self.channels, height // 8, width // 8),
492+
device=torch_device,
493+
dtype=self.dtype,
494+
generator=torch.Generator(torch_device).manual_seed(13),
495+
)
496+
inputs["height"] = height
497+
inputs["width"] = width
498+
output = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
499+
output.save(f"test_multi_res_{height}_{width}.png")
500+
501+
502+
@slow
503+
class RemoteAutoencoderKLSDv1SlowTests(
504+
RemoteAutoencoderKLSlowTestMixin,
505+
unittest.TestCase,
506+
):
507+
endpoint = ENDPOINT_SD_V1
508+
dtype = torch.float16
509+
scaling_factor = 0.18215
510+
shift_factor = None
511+
512+
513+
@slow
514+
class RemoteAutoencoderKLSDXLSlowTests(
515+
RemoteAutoencoderKLSlowTestMixin,
516+
unittest.TestCase,
517+
):
518+
endpoint = ENDPOINT_SD_XL
519+
dtype = torch.float16
520+
scaling_factor = 0.13025
521+
shift_factor = None
522+
523+
524+
@slow
525+
class RemoteAutoencoderKLFluxSlowTests(
526+
RemoteAutoencoderKLSlowTestMixin,
527+
unittest.TestCase,
528+
):
529+
channels = 16
530+
endpoint = ENDPOINT_FLUX
531+
dtype = torch.bfloat16
532+
scaling_factor = 0.3611
533+
shift_factor = 0.1159

0 commit comments

Comments
 (0)