2626 DECODE_ENDPOINT_HUNYUAN_VIDEO ,
2727 DECODE_ENDPOINT_SD_V1 ,
2828 DECODE_ENDPOINT_SD_XL ,
29+ DECODE_ENDPOINT_WAN_2_1 ,
2930)
3031from diffusers .utils .remote_utils import (
3132 remote_decode ,
@@ -176,18 +177,6 @@ def test_output_type_pt_partial_postprocess_return_type_pt(self):
176177 f"{ output_slice } " ,
177178 )
178179
179- def test_do_scaling_deprecation (self ):
180- inputs = self .get_dummy_inputs ()
181- inputs .pop ("scaling_factor" , None )
182- inputs .pop ("shift_factor" , None )
183- with self .assertWarns (FutureWarning ) as warning :
184- _ = remote_decode (output_type = "pt" , partial_postprocess = True , ** inputs )
185- self .assertEqual (
186- str (warning .warnings [0 ].message ),
187- "`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required." ,
188- str (warning .warnings [0 ].message ),
189- )
190-
191180 def test_input_tensor_type_base64_deprecation (self ):
192181 inputs = self .get_dummy_inputs ()
193182 with self .assertWarns (FutureWarning ) as warning :
@@ -209,7 +198,7 @@ def test_output_tensor_type_base64_deprecation(self):
209198 )
210199
211200
212- class RemoteAutoencoderKLHunyuanVideoMixin (RemoteAutoencoderKLMixin ):
201+ class RemoteAutoencoderKLVideoMixin (RemoteAutoencoderKLMixin ):
213202 def test_no_scaling (self ):
214203 inputs = self .get_dummy_inputs ()
215204 if inputs ["scaling_factor" ] is not None :
@@ -221,7 +210,6 @@ def test_no_scaling(self):
221210 processor = self .processor_cls ()
222211 output = remote_decode (
223212 output_type = "pt" ,
224- # required for now, will be removed in next update
225213 do_scaling = False ,
226214 processor = processor ,
227215 ** inputs ,
@@ -337,6 +325,8 @@ def test_output_type_mp4(self):
337325 inputs = self .get_dummy_inputs ()
338326 output = remote_decode (output_type = "mp4" , return_type = "mp4" , ** inputs )
339327 self .assertTrue (isinstance (output , bytes ), f"Expected `bytes` output, got { type (output )} " )
328+ with open ("test.mp4" , "wb" ) as f :
329+ f .write (output )
340330
341331
342332class RemoteAutoencoderKLSDv1Tests (
@@ -442,7 +432,7 @@ class RemoteAutoencoderKLFluxPackedTests(
442432
443433
444434class RemoteAutoencoderKLHunyuanVideoTests (
445- RemoteAutoencoderKLHunyuanVideoMixin ,
435+ RemoteAutoencoderKLVideoMixin ,
446436 unittest .TestCase ,
447437):
448438 shape = (
@@ -467,6 +457,31 @@ class RemoteAutoencoderKLHunyuanVideoTests(
467457 return_pt_slice = torch .tensor ([0.1656 , 0.2661 , 0.3157 , 0.0693 , 0.1755 , 0.2252 , 0.0127 , 0.1221 , 0.1708 ])
468458
469459
460+ class RemoteAutoencoderKLWanTests (
461+ RemoteAutoencoderKLVideoMixin ,
462+ unittest .TestCase ,
463+ ):
464+ shape = (
465+ 1 ,
466+ 16 ,
467+ 3 ,
468+ 40 ,
469+ 64 ,
470+ )
471+ out_hw = (
472+ 320 ,
473+ 512 ,
474+ )
475+ endpoint = DECODE_ENDPOINT_WAN_2_1
476+ dtype = torch .float16
477+ processor_cls = VideoProcessor
478+ output_pt_slice = torch .tensor ([203 , 174 , 178 , 204 , 171 , 177 , 209 , 183 , 182 ], dtype = torch .uint8 )
479+ partial_postprocess_return_pt_slice = torch .tensor (
480+ [206 , 209 , 221 , 202 , 199 , 222 , 207 , 210 , 217 ], dtype = torch .uint8
481+ )
482+ return_pt_slice = torch .tensor ([0.6196 , 0.6382 , 0.7310 , 0.5869 , 0.5625 , 0.7373 , 0.6240 , 0.6465 , 0.7002 ])
483+
484+
470485class RemoteAutoencoderKLSlowTestMixin :
471486 channels : int = 4
472487 endpoint : str = None
0 commit comments