2424from diffusers .utils .remote_utils import remote_decode
2525from diffusers .utils .testing_utils import (
2626 enable_full_determinism ,
27+ slow ,
2728 torch_all_close ,
2829 torch_device ,
2930)
3233
3334enable_full_determinism ()
3435
36+ ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
37+ ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
38+ ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
39+ ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
40+
3541
3642class RemoteAutoencoderKLMixin :
3743 shape : Tuple [int , ...] = None
@@ -344,7 +350,7 @@ class RemoteAutoencoderKLSDv1Tests(
344350 512 ,
345351 512 ,
346352 )
347- endpoint = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
353+ endpoint = ENDPOINT_SD_V1
348354 dtype = torch .float16
349355 scaling_factor = 0.18215
350356 shift_factor = None
@@ -368,7 +374,7 @@ class RemoteAutoencoderKLSDXLTests(
368374 1024 ,
369375 1024 ,
370376 )
371- endpoint = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
377+ endpoint = ENDPOINT_SD_XL
372378 dtype = torch .float16
373379 scaling_factor = 0.13025
374380 shift_factor = None
@@ -392,7 +398,7 @@ class RemoteAutoencoderKLFluxTests(
392398 1024 ,
393399 1024 ,
394400 )
395- endpoint = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
401+ endpoint = ENDPOINT_FLUX
396402 dtype = torch .bfloat16
397403 scaling_factor = 0.3611
398404 shift_factor = 0.1159
@@ -419,7 +425,7 @@ class RemoteAutoencoderKLFluxPackedTests(
419425 )
420426 height = 1024
421427 width = 1024
422- endpoint = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
428+ endpoint = ENDPOINT_FLUX
423429 dtype = torch .bfloat16
424430 scaling_factor = 0.3611
425431 shift_factor = 0.1159
@@ -447,7 +453,7 @@ class RemoteAutoencoderKLHunyuanVideoTests(
447453 320 ,
448454 512 ,
449455 )
450- endpoint = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
456+ endpoint = ENDPOINT_HUNYUAN_VIDEO
451457 dtype = torch .float16
452458 scaling_factor = 0.476986
453459 processor_cls = VideoProcessor
@@ -456,3 +462,72 @@ class RemoteAutoencoderKLHunyuanVideoTests(
456462 [149 , 161 , 168 , 136 , 150 , 156 , 129 , 143 , 149 ], dtype = torch .uint8
457463 )
458464 return_pt_slice = torch .tensor ([0.1656 , 0.2661 , 0.3157 , 0.0693 , 0.1755 , 0.2252 , 0.0127 , 0.1221 , 0.1708 ])
465+
466+
467+ class RemoteAutoencoderKLSlowTestMixin :
468+ channels : int = 4
469+ endpoint : str = None
470+ dtype : torch .dtype = None
471+ scaling_factor : float = None
472+ shift_factor : float = None
473+ width : int = None
474+ height : int = None
475+
476+ def get_dummy_inputs (self ):
477+ inputs = {
478+ "endpoint" : self .endpoint ,
479+ "scaling_factor" : self .scaling_factor ,
480+ "shift_factor" : self .shift_factor ,
481+ "height" : self .height ,
482+ "width" : self .width ,
483+ }
484+ return inputs
485+
486+ def test_multi_res (self ):
487+ inputs = self .get_dummy_inputs ()
488+ for height in {320 , 512 , 640 , 704 , 896 , 1024 , 1208 , 1384 , 1536 , 1608 , 1864 , 2048 }:
489+ for width in {320 , 512 , 640 , 704 , 896 , 1024 , 1208 , 1384 , 1536 , 1608 , 1864 , 2048 }:
490+ inputs ["tensor" ] = torch .randn (
491+ (1 , self .channels , height // 8 , width // 8 ),
492+ device = torch_device ,
493+ dtype = self .dtype ,
494+ generator = torch .Generator (torch_device ).manual_seed (13 ),
495+ )
496+ inputs ["height" ] = height
497+ inputs ["width" ] = width
498+ output = remote_decode (output_type = "pt" , partial_postprocess = True , ** inputs )
499+ output .save (f"test_multi_res_{ height } _{ width } .png" )
500+
501+
502+ @slow
503+ class RemoteAutoencoderKLSDv1SlowTests (
504+ RemoteAutoencoderKLSlowTestMixin ,
505+ unittest .TestCase ,
506+ ):
507+ endpoint = ENDPOINT_SD_V1
508+ dtype = torch .float16
509+ scaling_factor = 0.18215
510+ shift_factor = None
511+
512+
513+ @slow
514+ class RemoteAutoencoderKLSDXLSlowTests (
515+ RemoteAutoencoderKLSlowTestMixin ,
516+ unittest .TestCase ,
517+ ):
518+ endpoint = ENDPOINT_SD_XL
519+ dtype = torch .float16
520+ scaling_factor = 0.13025
521+ shift_factor = None
522+
523+
524+ @slow
525+ class RemoteAutoencoderKLFluxSlowTests (
526+ RemoteAutoencoderKLSlowTestMixin ,
527+ unittest .TestCase ,
528+ ):
529+ channels = 16
530+ endpoint = ENDPOINT_FLUX
531+ dtype = torch .bfloat16
532+ scaling_factor = 0.3611
533+ shift_factor = 0.1159
0 commit comments