
Commit 562a4c0

hunyuanvideo test
1 parent 7df21f2 commit 562a4c0

File tree

src/diffusers/utils/remote_utils.py
tests/remote/test_remote_decode.py

2 files changed: +244 -0 lines changed


src/diffusers/utils/remote_utils.py

Lines changed: 10 additions & 0 deletions
@@ -134,6 +134,7 @@ def postprocess(
 
 def prepare(
     tensor: "torch.Tensor",
+    processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
     do_scaling: bool = True,
     scaling_factor: Optional[float] = None,
     shift_factor: Optional[float] = None,
@@ -162,6 +163,14 @@ def prepare(
     if height is not None and width is not None:
         parameters["height"] = height
         parameters["width"] = width
+    headers["Content-Type"] = "tensor/binary"
+    headers["Accept"] = "tensor/binary"
+    if output_type == "pil" and image_format == "jpg" and processor is None:
+        headers["Accept"] = "image/jpeg"
+    elif output_type == "pil" and image_format == "png" and processor is None:
+        headers["Accept"] = "image/png"
+    elif output_type == "mp4":
+        headers["Accept"] = "text/plain"
     tensor_data = safetensors.torch._tobytes(tensor, "tensor")
     return {"data": tensor_data, "params": parameters, "headers": headers}
@@ -291,6 +300,7 @@ def remote_decode(
     )
     kwargs = prepare(
         tensor=tensor,
+        processor=processor,
         do_scaling=do_scaling,
         scaling_factor=scaling_factor,
         shift_factor=shift_factor,
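
Read as a decision table, the header logic added to `prepare` defaults to a binary tensor round trip and only requests server-side encoding when no local processor will run. A standalone sketch of that branch (the helper name is hypothetical, not part of this commit):

def select_accept_header(output_type, image_format="jpg", processor=None):
    # Mirrors the Accept selection added to prepare() above.
    accept = "tensor/binary"  # default: raw tensor bytes round trip
    if output_type == "pil" and image_format == "jpg" and processor is None:
        accept = "image/jpeg"  # endpoint returns an encoded JPEG
    elif output_type == "pil" and image_format == "png" and processor is None:
        accept = "image/png"  # endpoint returns an encoded PNG
    elif output_type == "mp4":
        accept = "text/plain"  # endpoint returns MP4 bytes
    return accept

assert select_accept_header("pil") == "image/jpeg"
assert select_accept_header("pil", image_format="png") == "image/png"
assert select_accept_header("mp4") == "text/plain"
assert select_accept_header("pt") == "tensor/binary"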

tests/remote/test_remote_decode.py

Lines changed: 234 additions & 0 deletions
@@ -200,6 +200,214 @@ def test_output_tensor_type_base64_deprecation(self):
         )
 
 
+class RemoteAutoencoderKLHunyuanVideoMixin:
+    shape: Tuple[int, ...] = None
+    out_hw: Tuple[int, int] = None
+    endpoint: str = None
+    dtype: torch.dtype = None
+    scaling_factor: float = None
+    shift_factor: float = None
+    processor_cls: Union[VaeImageProcessor, VideoProcessor] = None
+    output_pil_slice: torch.Tensor = None
+    output_pt_slice: torch.Tensor = None
+    partial_postprocess_return_pt_slice: torch.Tensor = None
+    return_pt_slice: torch.Tensor = None
+    width: int = None
+    height: int = None
+
+    def get_dummy_inputs(self):
+        inputs = {
+            "endpoint": self.endpoint,
+            "tensor": torch.randn(
+                self.shape,
+                device=torch_device,
+                dtype=self.dtype,
+                generator=torch.Generator(torch_device).manual_seed(13),
+            ),
+            "scaling_factor": self.scaling_factor,
+            "shift_factor": self.shift_factor,
+            "height": self.height,
+            "width": self.width,
+        }
+        return inputs
+
+    def test_no_scaling(self):
+        inputs = self.get_dummy_inputs()
+        if inputs["scaling_factor"] is not None:
+            inputs["tensor"] = inputs["tensor"] / inputs["scaling_factor"]
+            inputs["scaling_factor"] = None
+        if inputs["shift_factor"] is not None:
+            inputs["tensor"] = inputs["tensor"] + inputs["shift_factor"]
+            inputs["shift_factor"] = None
+        processor = self.processor_cls()
+        output = remote_decode(
+            output_type="pt",
+            # required for now, will be removed in next update
+            do_scaling=False,
+            processor=processor,
+            **inputs,
+        )
+        self.assertTrue(
+            isinstance(output, list) and isinstance(output[0], PIL.Image.Image),
+            f"Expected `List[PIL.Image.Image]` output, got {type(output)}",
+        )
+        self.assertEqual(
+            output[0].height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output[0].height}"
+        )
+        self.assertEqual(
+            output[0].width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output[0].width}"
+        )
+        output_slice = torch.from_numpy(np.array(output[0])[0, -3:, -3:].flatten())
+        self.assertTrue(
+            torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1, atol=1),
+            f"{output_slice}",
+        )
+
+    def test_output_type_pt(self):
+        inputs = self.get_dummy_inputs()
+        processor = self.processor_cls()
+        output = remote_decode(output_type="pt", processor=processor, **inputs)
+        self.assertTrue(
+            isinstance(output, list) and isinstance(output[0], PIL.Image.Image),
+            f"Expected `List[PIL.Image.Image]` output, got {type(output)}",
+        )
+        self.assertEqual(
+            output[0].height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output[0].height}"
+        )
+        self.assertEqual(
+            output[0].width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output[0].width}"
+        )
+        output_slice = torch.from_numpy(np.array(output[0])[0, -3:, -3:].flatten())
+        self.assertTrue(
+            torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1, atol=1),
+            f"{output_slice}",
+        )
+
+    # output is visually the same, slice is flaky?
+    def test_output_type_pil(self):
+        inputs = self.get_dummy_inputs()
+        processor = self.processor_cls()
+        output = remote_decode(output_type="pil", processor=processor, **inputs)
+        self.assertTrue(
+            isinstance(output, list) and isinstance(output[0], PIL.Image.Image),
+            f"Expected `List[PIL.Image.Image]` output, got {type(output)}",
+        )
+        self.assertEqual(
+            output[0].height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output[0].height}"
+        )
+        self.assertEqual(
+            output[0].width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output[0].width}"
+        )
+
+    def test_output_type_pil_image_format(self):
+        inputs = self.get_dummy_inputs()
+        processor = self.processor_cls()
+        output = remote_decode(output_type="pil", processor=processor, image_format="png", **inputs)
+        self.assertTrue(
+            isinstance(output, list) and isinstance(output[0], PIL.Image.Image),
+            f"Expected `List[PIL.Image.Image]` output, got {type(output)}",
+        )
+        self.assertEqual(
+            output[0].height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output[0].height}"
+        )
+        self.assertEqual(
+            output[0].width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output[0].width}"
+        )
+        output_slice = torch.from_numpy(np.array(output[0])[0, -3:, -3:].flatten())
+        self.assertTrue(
+            torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1, atol=1),
+            f"{output_slice}",
+        )
+
+    def test_output_type_pt_partial_postprocess(self):
+        inputs = self.get_dummy_inputs()
+        output = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
+        self.assertTrue(
+            isinstance(output, list) and isinstance(output[0], PIL.Image.Image),
+            f"Expected `List[PIL.Image.Image]` output, got {type(output)}",
+        )
+        self.assertEqual(
+            output[0].height, self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output[0].height}"
+        )
+        self.assertEqual(
+            output[0].width, self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output[0].width}"
+        )
+        output_slice = torch.from_numpy(np.array(output[0])[0, -3:, -3:].flatten())
+        self.assertTrue(
+            torch_all_close(output_slice, self.output_pt_slice.to(output_slice.dtype), rtol=1, atol=1),
+            f"{output_slice}",
+        )
+
+    def test_output_type_pt_return_type_pt(self):
+        inputs = self.get_dummy_inputs()
+        output = remote_decode(output_type="pt", return_type="pt", **inputs)
+        self.assertTrue(isinstance(output, torch.Tensor), f"Expected `torch.Tensor` output, got {type(output)}")
+        self.assertEqual(
+            output.shape[3], self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.shape[3]}"
+        )
+        self.assertEqual(
+            output.shape[4], self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.shape[4]}"
+        )
+        output_slice = output[0, 0, 0, -3:, -3:].flatten()
+        self.assertTrue(
+            torch_all_close(output_slice, self.return_pt_slice.to(output_slice.dtype), rtol=1e-3, atol=1e-3),
+            f"{output_slice}",
+        )
+
+    def test_output_type_pt_partial_postprocess_return_type_pt(self):
+        inputs = self.get_dummy_inputs()
+        output = remote_decode(output_type="pt", partial_postprocess=True, return_type="pt", **inputs)
+        self.assertTrue(isinstance(output, torch.Tensor), f"Expected `torch.Tensor` output, got {type(output)}")
+        self.assertEqual(
+            output.shape[1], self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.shape[1]}"
+        )
+        self.assertEqual(
+            output.shape[2], self.out_hw[1], f"Expected image width {self.out_hw[1]}, got {output.shape[2]}"
+        )
+        output_slice = output[0, -3:, -3:, 0].flatten().cpu()
+        self.assertTrue(
+            torch_all_close(output_slice, self.partial_postprocess_return_pt_slice.to(output_slice.dtype), rtol=1e-2),
+            f"{output_slice}",
+        )
+
+    def test_output_type_mp4(self):
+        inputs = self.get_dummy_inputs()
+        output = remote_decode(output_type="mp4", return_type="mp4", **inputs)
+        self.assertTrue(isinstance(output, bytes), f"Expected `bytes` output, got {type(output)}")
+
+    def test_do_scaling_deprecation(self):
+        inputs = self.get_dummy_inputs()
+        inputs.pop("scaling_factor", None)
+        inputs.pop("shift_factor", None)
+        with self.assertWarns(FutureWarning) as warning:
+            _ = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
+        self.assertEqual(
+            str(warning.warnings[0].message),
+            "`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.",
+            str(warning.warnings[0].message),
+        )
+
+    def test_input_tensor_type_base64_deprecation(self):
+        inputs = self.get_dummy_inputs()
+        with self.assertWarns(FutureWarning) as warning:
+            _ = remote_decode(output_type="pt", input_tensor_type="base64", partial_postprocess=True, **inputs)
+        self.assertEqual(
+            str(warning.warnings[0].message),
+            "input_tensor_type='base64' is deprecated. Using `binary`.",
+            str(warning.warnings[0].message),
+        )
+
+    def test_output_tensor_type_base64_deprecation(self):
+        inputs = self.get_dummy_inputs()
+        with self.assertWarns(FutureWarning) as warning:
+            _ = remote_decode(output_type="pt", output_tensor_type="base64", partial_postprocess=True, **inputs)
+        self.assertEqual(
+            str(warning.warnings[0].message),
+            "output_tensor_type='base64' is deprecated. Using `binary`.",
+            str(warning.warnings[0].message),
+        )
+
+
 class RemoteAutoencoderKLSDv1Tests(
     RemoteAutoencoderKLMixin,
     unittest.TestCase,
@@ -300,3 +508,29 @@ class RemoteAutoencoderKLFluxPackedTests(
         [168, 212, 202, 155, 191, 185, 150, 180, 168], dtype=torch.uint8
     )
     return_pt_slice = torch.tensor([0.3198, 0.6631, 0.5864, 0.2131, 0.4944, 0.4482, 0.1776, 0.4153, 0.3176])
+
+
+class RemoteAutoencoderKLHunyuanVideoTests(
+    RemoteAutoencoderKLHunyuanVideoMixin,
+    unittest.TestCase,
+):
+    shape = (
+        1,
+        16,
+        3,
+        40,
+        64,
+    )
+    out_hw = (
+        320,
+        512,
+    )
+    endpoint = "https://lsx2injm3ts8wbvv.us-east-1.aws.endpoints.huggingface.cloud/"
+    dtype = torch.float16
+    scaling_factor = 0.476986
+    processor_cls = VideoProcessor
+    output_pt_slice = torch.tensor([112, 92, 85, 112, 93, 85, 112, 94, 85], dtype=torch.uint8)
+    partial_postprocess_return_pt_slice = torch.tensor(
+        [149, 161, 168, 136, 150, 156, 129, 143, 149], dtype=torch.uint8
+    )
+    return_pt_slice = torch.tensor([0.1656, 0.2661, 0.3157, 0.0693, 0.1755, 0.2252, 0.0127, 0.1221, 0.1708])
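
Given these class attributes, the mp4 path exercised by `test_output_type_mp4` corresponds roughly to the sketch below; endpoint, latent shape, dtype, and scaling_factor are copied from the test class, while the output filename is illustrative:

import torch

from diffusers.utils.remote_utils import remote_decode

# Random stand-in for a HunyuanVideo latent (batch, channels, frames, height, width),
# matching the shape and dtype used by the tests above.
latent = torch.randn(1, 16, 3, 40, 64, dtype=torch.float16)
video = remote_decode(
    endpoint="https://lsx2injm3ts8wbvv.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latent,
    scaling_factor=0.476986,
    output_type="mp4",
    return_type="mp4",  # per test_output_type_mp4, the result is raw MP4 bytes
)
with open("hunyuan_sample.mp4", "wb") as f:  # illustrative output path
    f.write(video)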
