2424from  diffusers .utils .remote_utils  import  remote_decode 
2525from  diffusers .utils .testing_utils  import  (
2626    enable_full_determinism ,
27+     slow ,
2728    torch_all_close ,
2829    torch_device ,
2930)
3233
# Called once at import time, before any of the test classes are defined.
enable_full_determinism()

# Hugging Face endpoint URLs, one per model family. The test classes in this
# module assign these to their `endpoint` class attribute.
ENDPOINT_SD_V1 = "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/"
ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
40+ 
3541
3642class  RemoteAutoencoderKLMixin :
3743    shape : Tuple [int , ...] =  None 
@@ -344,7 +350,7 @@ class RemoteAutoencoderKLSDv1Tests(
344350        512 ,
345351        512 ,
346352    )
347-     endpoint  =  "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/" 
353+     endpoint  =  ENDPOINT_SD_V1 
348354    dtype  =  torch .float16 
349355    scaling_factor  =  0.18215 
350356    shift_factor  =  None 
@@ -368,7 +374,7 @@ class RemoteAutoencoderKLSDXLTests(
368374        1024 ,
369375        1024 ,
370376    )
371-     endpoint  =  "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/" 
377+     endpoint  =  ENDPOINT_SD_XL 
372378    dtype  =  torch .float16 
373379    scaling_factor  =  0.13025 
374380    shift_factor  =  None 
@@ -392,7 +398,7 @@ class RemoteAutoencoderKLFluxTests(
392398        1024 ,
393399        1024 ,
394400    )
395-     endpoint  =  "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/" 
401+     endpoint  =  ENDPOINT_FLUX 
396402    dtype  =  torch .bfloat16 
397403    scaling_factor  =  0.3611 
398404    shift_factor  =  0.1159 
@@ -419,7 +425,7 @@ class RemoteAutoencoderKLFluxPackedTests(
419425    )
420426    height  =  1024 
421427    width  =  1024 
422-     endpoint  =  "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/" 
428+     endpoint  =  ENDPOINT_FLUX 
423429    dtype  =  torch .bfloat16 
424430    scaling_factor  =  0.3611 
425431    shift_factor  =  0.1159 
@@ -447,7 +453,7 @@ class RemoteAutoencoderKLHunyuanVideoTests(
447453        320 ,
448454        512 ,
449455    )
450-     endpoint  =  "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/" 
456+     endpoint  =  ENDPOINT_HUNYUAN_VIDEO 
451457    dtype  =  torch .float16 
452458    scaling_factor  =  0.476986 
453459    processor_cls  =  VideoProcessor 
@@ -456,3 +462,72 @@ class RemoteAutoencoderKLHunyuanVideoTests(
456462        [149 , 161 , 168 , 136 , 150 , 156 , 129 , 143 , 149 ], dtype = torch .uint8 
457463    )
458464    return_pt_slice  =  torch .tensor ([0.1656 , 0.2661 , 0.3157 , 0.0693 , 0.1755 , 0.2252 , 0.0127 , 0.1221 , 0.1708 ])
465+ 
466+ 
class RemoteAutoencoderKLSlowTestMixin:
    """Mixin that exercises ``remote_decode`` over a grid of output resolutions.

    Concrete test classes set the class attributes below to describe the
    endpoint they target. The mixin itself holds no assertions; the test
    passes when every decode-and-save round trip completes without raising.
    """

    # Channel dimension of the random latent tensor sent to the endpoint.
    channels: int = 4
    # URL passed to `remote_decode` as `endpoint`.
    endpoint: str = None
    # dtype of the random latent tensor.
    dtype: torch.dtype = None
    # Forwarded unchanged to `remote_decode`.
    scaling_factor: float = None
    shift_factor: float = None
    # Initial values for the "width"/"height" inputs; `test_multi_res`
    # overwrites both on every iteration.
    width: int = None
    height: int = None

    def get_dummy_inputs(self):
        """Return the keyword arguments shared by every ``remote_decode`` call."""
        return {
            "endpoint": self.endpoint,
            "scaling_factor": self.scaling_factor,
            "shift_factor": self.shift_factor,
            "height": self.height,
            "width": self.width,
        }

    def test_multi_res(self):
        """Decode a seeded random latent at every height/width combination.

        Each result is saved to ``test_multi_res_<height>_<width>.png`` in the
        current working directory.
        """
        # One ordered tuple, built once, serves both axes; this keeps the
        # request order stable and readable.
        sizes = (320, 512, 640, 704, 896, 1024, 1208, 1384, 1536, 1608, 1864, 2048)
        inputs = self.get_dummy_inputs()
        for height in sizes:
            for width in sizes:
                # The latent's spatial dims are the pixel dims floor-divided by 8.
                # A fresh generator with the same seed is used for every tensor.
                inputs["tensor"] = torch.randn(
                    (1, self.channels, height // 8, width // 8),
                    device=torch_device,
                    dtype=self.dtype,
                    generator=torch.Generator(torch_device).manual_seed(13),
                )
                inputs["height"] = height
                inputs["width"] = width
                output = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
                # NOTE(review): assumes the returned object exposes `.save(path)`
                # (e.g. a PIL image) -- confirm against `remote_decode`.
                output.save(f"test_multi_res_{height}_{width}.png")
500+ 
501+ 
@slow
class RemoteAutoencoderKLSDv1SlowTests(RemoteAutoencoderKLSlowTestMixin, unittest.TestCase):
    """Slow remote-decode tests that target ``ENDPOINT_SD_V1``."""

    # Only a scaling factor is configured; shift_factor is explicitly None.
    shift_factor = None
    scaling_factor = 0.18215
    dtype = torch.float16
    endpoint = ENDPOINT_SD_V1
511+ 
512+ 
@slow
class RemoteAutoencoderKLSDXLSlowTests(RemoteAutoencoderKLSlowTestMixin, unittest.TestCase):
    """Slow remote-decode tests that target ``ENDPOINT_SD_XL``."""

    # Only a scaling factor is configured; shift_factor is explicitly None.
    shift_factor = None
    scaling_factor = 0.13025
    dtype = torch.float16
    endpoint = ENDPOINT_SD_XL
522+ 
523+ 
@slow
class RemoteAutoencoderKLFluxSlowTests(RemoteAutoencoderKLSlowTestMixin, unittest.TestCase):
    """Slow remote-decode tests that target ``ENDPOINT_FLUX``."""

    # Both a scaling factor and a shift factor are configured for this endpoint.
    shift_factor = 0.1159
    scaling_factor = 0.3611
    dtype = torch.bfloat16
    endpoint = ENDPOINT_FLUX
    # Overrides the mixin's channel count for the latent tensor.
    channels = 16
0 commit comments