@@ -200,6 +200,214 @@ def test_output_tensor_type_base64_deprecation(self):
200200 )
201201
202202
class RemoteAutoencoderKLHunyuanVideoMixin:
    """Shared `remote_decode` test cases for a video VAE endpoint.

    Concrete test classes combine this mixin with `unittest.TestCase` and override the class attributes below.
    Every test builds its arguments with `get_dummy_inputs` and calls `remote_decode` on `endpoint`
    (presumably a live HTTP service, so network access is expected -- confirm against the test setup).
    """

    # Shape handed to `torch.randn` when building the dummy input tensor.
    shape: Tuple[int, ...] = None
    # Expected (height, width) of the decoded output.
    out_hw: Tuple[int, int] = None
    # Forwarded to `remote_decode` as `endpoint`.
    endpoint: str = None
    # dtype of the dummy input tensor.
    dtype: torch.dtype = None
    # Forwarded to `remote_decode`; `None` means "not provided".
    scaling_factor: float = None
    shift_factor: float = None
    # Instantiated without arguments and passed to `remote_decode` as `processor`.
    processor_cls: Union[VaeImageProcessor, VideoProcessor] = None
    # Not read by any test in this mixin.
    output_pil_slice: torch.Tensor = None
    # Reference values for the slice taken from the first PIL frame.
    output_pt_slice: torch.Tensor = None
    # Reference values for `partial_postprocess=True, return_type="pt"`.
    partial_postprocess_return_pt_slice: torch.Tensor = None
    # Reference values for `return_type="pt"` without partial postprocessing.
    return_pt_slice: torch.Tensor = None
    # Forwarded to `remote_decode` as `width` / `height`.
    width: int = None
    height: int = None

    def get_dummy_inputs(self):
        """Return the keyword arguments shared by every `remote_decode` call in this mixin.

        The tensor is drawn from a generator seeded with 13 so the reference slices stay comparable between runs.
        """
        inputs = {
            "endpoint": self.endpoint,
            "tensor": torch.randn(
                self.shape,
                device=torch_device,
                dtype=self.dtype,
                generator=torch.Generator(torch_device).manual_seed(13),
            ),
            "scaling_factor": self.scaling_factor,
            "shift_factor": self.shift_factor,
            "height": self.height,
            "width": self.width,
        }
        return inputs

    def _assert_pil_frames(self, output):
        """Assert `output` is a list of PIL images whose first element has the size given by `out_hw`."""
        self.assertTrue(
            isinstance(output, list) and isinstance(output[0], PIL.Image.Image),
            f"Expected `List[PIL.Image.Image]` output, got {type(output)}",
        )
        self.assertEqual(
            output[0].height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output[0].height}"
        )
        self.assertEqual(
            output[0].width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output[0].width}"
        )

    def _assert_first_frame_slice(self, output):
        """Compare a 9-value slice of the first image in `output` with `output_pt_slice`."""
        # Index 0 on the first axis, last three entries on the next two axes
        # (presumably first pixel row, last three columns, RGB channels -- confirm).
        output_slice = torch.from_numpy(np.array(output[0])[0, -3:, -3:].flatten())
        self.assertTrue(
            torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1, atol=1),
            f"{output_slice}",
        )

    def test_no_scaling(self):
        """Apply scaling/shift on the caller side, then decode with both factors set to `None`."""
        inputs = self.get_dummy_inputs()
        if inputs["scaling_factor"] is not None:
            inputs["tensor"] = inputs["tensor"] / inputs["scaling_factor"]
            inputs["scaling_factor"] = None
        if inputs["shift_factor"] is not None:
            inputs["tensor"] = inputs["tensor"] + inputs["shift_factor"]
            inputs["shift_factor"] = None
        processor = self.processor_cls()
        output = remote_decode(
            output_type="pt",
            # required for now, will be removed in next update
            do_scaling=False,
            processor=processor,
            **inputs,
        )
        self._assert_pil_frames(output)
        self._assert_first_frame_slice(output)

    def test_output_type_pt(self):
        """`output_type="pt"` with a processor yields PIL frames matching the reference slice."""
        inputs = self.get_dummy_inputs()
        processor = self.processor_cls()
        output = remote_decode(output_type="pt", processor=processor, **inputs)
        self._assert_pil_frames(output)
        self._assert_first_frame_slice(output)

    def test_output_type_pil(self):
        """`output_type="pil"` yields PIL frames of the expected size."""
        inputs = self.get_dummy_inputs()
        processor = self.processor_cls()
        output = remote_decode(output_type="pil", processor=processor, **inputs)
        # No slice comparison here: output is visually the same, but the slice appears to be flaky.
        self._assert_pil_frames(output)

    def test_output_type_pil_image_format(self):
        """`output_type="pil"` with `image_format="png"` yields PIL frames matching the reference slice."""
        inputs = self.get_dummy_inputs()
        processor = self.processor_cls()
        output = remote_decode(output_type="pil", processor=processor, image_format="png", **inputs)
        self._assert_pil_frames(output)
        self._assert_first_frame_slice(output)

    def test_output_type_pt_partial_postprocess(self):
        """`partial_postprocess=True` without a processor still yields PIL frames matching the reference slice."""
        inputs = self.get_dummy_inputs()
        output = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
        self._assert_pil_frames(output)
        self._assert_first_frame_slice(output)

    def test_output_type_pt_return_type_pt(self):
        """`return_type="pt"` yields a tensor whose axes 3 and 4 hold height and width."""
        inputs = self.get_dummy_inputs()
        output = remote_decode(output_type="pt", return_type="pt", **inputs)
        self.assertTrue(isinstance(output, torch.Tensor), f"Expected `torch.Tensor` output, got {type(output)}")
        # Failure messages report the same axes that are being compared.
        self.assertEqual(
            output.shape[3], self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.shape[3]}"
        )
        self.assertEqual(
            output.shape[4], self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.shape[4]}"
        )
        output_slice = output[0, 0, 0, -3:, -3:].flatten()
        self.assertTrue(
            torch_all_close(output_slice, self.return_pt_slice.to(output_slice.dtype), rtol=1e-3, atol=1e-3),
            f"{output_slice}",
        )

    def test_output_type_pt_partial_postprocess_return_type_pt(self):
        """Partial postprocessing with `return_type="pt"` yields a tensor with height/width on axes 1 and 2."""
        inputs = self.get_dummy_inputs()
        output = remote_decode(output_type="pt", partial_postprocess=True, return_type="pt", **inputs)
        self.assertTrue(isinstance(output, torch.Tensor), f"Expected `torch.Tensor` output, got {type(output)}")
        self.assertEqual(
            output.shape[1], self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.shape[1]}"
        )
        self.assertEqual(
            output.shape[2], self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.shape[2]}"
        )
        output_slice = output[0, -3:, -3:, 0].flatten().cpu()
        self.assertTrue(
            torch_all_close(output_slice, self.partial_postprocess_return_pt_slice.to(output_slice.dtype), rtol=1e-2),
            f"{output_slice}",
        )

    def test_output_type_mp4(self):
        """`output_type="mp4"` with `return_type="mp4"` yields raw `bytes`."""
        inputs = self.get_dummy_inputs()
        output = remote_decode(output_type="mp4", return_type="mp4", **inputs)
        self.assertTrue(isinstance(output, bytes), f"Expected `bytes` output, got {type(output)}")

    def test_do_scaling_deprecation(self):
        """Omitting both `scaling_factor` and `shift_factor` emits the `do_scaling` FutureWarning."""
        inputs = self.get_dummy_inputs()
        inputs.pop("scaling_factor", None)
        inputs.pop("shift_factor", None)
        with self.assertWarns(FutureWarning) as warning:
            _ = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
        # The first recorded warning is expected to be the deprecation notice.
        self.assertEqual(
            str(warning.warnings[0].message),
            "`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.",
            str(warning.warnings[0].message),
        )

    def test_input_tensor_type_base64_deprecation(self):
        """Passing `input_tensor_type="base64"` emits its FutureWarning."""
        inputs = self.get_dummy_inputs()
        with self.assertWarns(FutureWarning) as warning:
            _ = remote_decode(output_type="pt", input_tensor_type="base64", partial_postprocess=True, **inputs)
        self.assertEqual(
            str(warning.warnings[0].message),
            "input_tensor_type='base64' is deprecated. Using `binary`.",
            str(warning.warnings[0].message),
        )

    def test_output_tensor_type_base64_deprecation(self):
        """Passing `output_tensor_type="base64"` emits its FutureWarning."""
        inputs = self.get_dummy_inputs()
        with self.assertWarns(FutureWarning) as warning:
            _ = remote_decode(output_type="pt", output_tensor_type="base64", partial_postprocess=True, **inputs)
        self.assertEqual(
            str(warning.warnings[0].message),
            "output_tensor_type='base64' is deprecated. Using `binary`.",
            str(warning.warnings[0].message),
        )
410+
203411class RemoteAutoencoderKLSDv1Tests (
204412 RemoteAutoencoderKLMixin ,
205413 unittest .TestCase ,
@@ -300,3 +508,29 @@ class RemoteAutoencoderKLFluxPackedTests(
300508 [168 , 212 , 202 , 155 , 191 , 185 , 150 , 180 , 168 ], dtype = torch .uint8
301509 )
302510 return_pt_slice = torch .tensor ([0.3198 , 0.6631 , 0.5864 , 0.2131 , 0.4944 , 0.4482 , 0.1776 , 0.4153 , 0.3176 ])
511+
512+
class RemoteAutoencoderKLHunyuanVideoTests(RemoteAutoencoderKLHunyuanVideoMixin, unittest.TestCase):
    """Concrete configuration that runs the HunyuanVideo mixin tests against one endpoint."""

    # Service under test and the processor class the mixin instantiates.
    endpoint = "https://lsx2injm3ts8wbvv.us-east-1.aws.endpoints.huggingface.cloud/"
    processor_cls = VideoProcessor

    # Dummy input tensor: 5-D shape (presumably batch, channels, frames, height, width -- confirm) in half precision.
    shape = (1, 16, 3, 40, 64)
    dtype = torch.float16
    scaling_factor = 0.476986

    # Expected decoded (height, width); numerically 8x the last two entries of `shape`.
    out_hw = (320, 512)

    # Reference slices the mixin compares decoded output against.
    output_pt_slice = torch.tensor([112, 92, 85, 112, 93, 85, 112, 94, 85], dtype=torch.uint8)
    partial_postprocess_return_pt_slice = torch.tensor(
        [149, 161, 168, 136, 150, 156, 129, 143, 149],
        dtype=torch.uint8,
    )
    return_pt_slice = torch.tensor(
        [0.1656, 0.2661, 0.3157, 0.0693, 0.1755, 0.2252, 0.0127, 0.1221, 0.1708],
    )
0 commit comments