Commit 79620bf

[hybrid inference 🍯🐝] Wan 2.1 decode
1 parent 73adcd8

Showing 5 changed files with 33 additions and 23 deletions.

docs/source/en/hybrid_inference/overview.md

Lines changed: 1 addition & 0 deletions

@@ -48,6 +48,7 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
 
 ## Changelog
 
+- March 11 2025: Added Wan 2.1 VAE decode
 - March 10 2025: Added VAE encode
 - March 2 2025: Initial release with VAE decoding
 

docs/source/en/hybrid_inference/vae_decode.md

Lines changed: 1 addition & 0 deletions

@@ -54,6 +54,7 @@ For the majority of these GPUs the memory usage % dictates other models (text en
 | **Stable Diffusion XL** | [https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud](https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
 | **Flux** | [https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud](https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
 | **HunyuanVideo** | [https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud](https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud) | [`hunyuanvideo-community/HunyuanVideo`](https://hf.co/hunyuanvideo-community/HunyuanVideo) |
+| **Wan2.1** | [https://lafotb093i5cnx2w.us-east-1.aws.endpoints.huggingface.cloud](https://lafotb093i5cnx2w.us-east-1.aws.endpoints.huggingface.cloud) | [`Wan-AI/Wan2.1-T2V-1.3B-Diffusers`](https://hf.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers) |
 
 
 > [!TIP]
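
A minimal sketch of calling the new Wan 2.1 decode endpoint, modeled on the inputs the new test case in this commit uses (latent shape `(1, 16, 3, 40, 64)`, `float16`) and on the existing `remote_decode` video flow; the dummy latent and output filename are placeholders, and a real latent would come from a Wan 2.1 denoising loop:

```python
import torch

from diffusers.utils.constants import DECODE_ENDPOINT_WAN_2_1
from diffusers.utils.remote_utils import remote_decode

# Placeholder latent matching the new test's shape/dtype; a real one
# comes from a Wan 2.1 pipeline's denoising loop.
latent = torch.randn(1, 16, 3, 40, 64, dtype=torch.float16)

# Request the decoded video as MP4 bytes, as the new mp4 test does.
video = remote_decode(
    endpoint=DECODE_ENDPOINT_WAN_2_1,
    tensor=latent,
    output_type="mp4",
    return_type="mp4",
)

with open("output.mp4", "wb") as f:
    f.write(video)
```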

src/diffusers/utils/constants.py

Lines changed: 1 addition & 1 deletion

@@ -62,7 +62,7 @@
 DECODE_ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
 DECODE_ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
 DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
-
+DECODE_ENDPOINT_WAN_2_1 = "https://lafotb093i5cnx2w.us-east-1.aws.endpoints.huggingface.cloud/"
 
 ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
 ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
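
The new constant sits alongside the other decode endpoints, so it imports the same way; this mirrors the import the updated test file below adds:

```python
# Same import path the updated tests use.
from diffusers.utils.constants import DECODE_ENDPOINT_WAN_2_1

print(DECODE_ENDPOINT_WAN_2_1)
# https://lafotb093i5cnx2w.us-east-1.aws.endpoints.huggingface.cloud/
```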

src/diffusers/utils/remote_utils.py

Lines changed: 0 additions & 7 deletions

@@ -80,13 +80,6 @@ def check_inputs_decode(
         and not isinstance(processor, (VaeImageProcessor, VideoProcessor))
     ):
         raise ValueError("`processor` is required.")
-    if do_scaling and scaling_factor is None:
-        deprecate(
-            "do_scaling",
-            "1.0.0",
-            "`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.",
-            standard_warn=False,
-        )
 
 
 def postprocess_decode(
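
The removed branch raised a FutureWarning when `do_scaling` was used without an explicit `scaling_factor`; its message already pointed callers at passing `scaling_factor` and `shift_factor` directly. A hedged sketch of that calling convention, with placeholder values (real factors come from the VAE's config, and Wan 2.1 may not need them at all):

```python
import torch

from diffusers.utils.constants import DECODE_ENDPOINT_WAN_2_1
from diffusers.utils.remote_utils import remote_decode

latent = torch.randn(1, 16, 3, 40, 64, dtype=torch.float16)  # placeholder latent

# `scaling_factor`/`shift_factor` are the kwargs the removed deprecation
# message points to; 0.18215 is SD v1's factor, used purely as a placeholder.
frames = remote_decode(
    endpoint=DECODE_ENDPOINT_WAN_2_1,
    tensor=latent,
    output_type="pt",
    partial_postprocess=True,
    scaling_factor=0.18215,
)
```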

tests/remote/test_remote_decode.py

Lines changed: 30 additions & 15 deletions

@@ -26,6 +26,7 @@
     DECODE_ENDPOINT_HUNYUAN_VIDEO,
     DECODE_ENDPOINT_SD_V1,
     DECODE_ENDPOINT_SD_XL,
+    DECODE_ENDPOINT_WAN_2_1,
 )
 from diffusers.utils.remote_utils import (
     remote_decode,

@@ -176,18 +177,6 @@ def test_output_type_pt_partial_postprocess_return_type_pt(self):
             f"{output_slice}",
         )
 
-    def test_do_scaling_deprecation(self):
-        inputs = self.get_dummy_inputs()
-        inputs.pop("scaling_factor", None)
-        inputs.pop("shift_factor", None)
-        with self.assertWarns(FutureWarning) as warning:
-            _ = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
-        self.assertEqual(
-            str(warning.warnings[0].message),
-            "`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.",
-            str(warning.warnings[0].message),
-        )
-
     def test_input_tensor_type_base64_deprecation(self):
         inputs = self.get_dummy_inputs()
         with self.assertWarns(FutureWarning) as warning:

@@ -209,7 +198,7 @@ def test_output_tensor_type_base64_deprecation(self):
         )
 
 
-class RemoteAutoencoderKLHunyuanVideoMixin(RemoteAutoencoderKLMixin):
+class RemoteAutoencoderKLVideoMixin(RemoteAutoencoderKLMixin):
     def test_no_scaling(self):
         inputs = self.get_dummy_inputs()
         if inputs["scaling_factor"] is not None:

@@ -221,7 +210,6 @@ def test_no_scaling(self):
         processor = self.processor_cls()
         output = remote_decode(
             output_type="pt",
-            # required for now, will be removed in next update
             do_scaling=False,
             processor=processor,
             **inputs,

@@ -337,6 +325,8 @@ def test_output_type_mp4(self):
         inputs = self.get_dummy_inputs()
         output = remote_decode(output_type="mp4", return_type="mp4", **inputs)
         self.assertTrue(isinstance(output, bytes), f"Expected `bytes` output, got {type(output)}")
+        with open("test.mp4", "wb") as f:
+            f.write(output)
 
 
 class RemoteAutoencoderKLSDv1Tests(

@@ -442,7 +432,7 @@ class RemoteAutoencoderKLFluxPackedTests(
 
 
 class RemoteAutoencoderKLHunyuanVideoTests(
-    RemoteAutoencoderKLHunyuanVideoMixin,
+    RemoteAutoencoderKLVideoMixin,
     unittest.TestCase,
 ):
     shape = (

@@ -467,6 +457,31 @@ class RemoteAutoencoderKLHunyuanVideoTests(
     return_pt_slice = torch.tensor([0.1656, 0.2661, 0.3157, 0.0693, 0.1755, 0.2252, 0.0127, 0.1221, 0.1708])
 
 
+class RemoteAutoencoderKLWanTests(
+    RemoteAutoencoderKLVideoMixin,
+    unittest.TestCase,
+):
+    shape = (
+        1,
+        16,
+        3,
+        40,
+        64,
+    )
+    out_hw = (
+        320,
+        512,
+    )
+    endpoint = DECODE_ENDPOINT_WAN_2_1
+    dtype = torch.float16
+    processor_cls = VideoProcessor
+    output_pt_slice = torch.tensor([203, 174, 178, 204, 171, 177, 209, 183, 182], dtype=torch.uint8)
+    partial_postprocess_return_pt_slice = torch.tensor(
+        [206, 209, 221, 202, 199, 222, 207, 210, 217], dtype=torch.uint8
+    )
+    return_pt_slice = torch.tensor([0.6196, 0.6382, 0.7310, 0.5869, 0.5625, 0.7373, 0.6240, 0.6465, 0.7002])
+
+
 class RemoteAutoencoderKLSlowTestMixin:
     channels: int = 4
     endpoint: str = None
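
One way to read the new test class's constants: `out_hw` of `(320, 512)` is exactly 8x the latent's trailing dims `(40, 64)`, consistent with an 8x spatial upscale in the Wan 2.1 VAE decode. The 8x factor is inferred from these numbers, not stated in the commit:

```python
import torch

# Latent shape from RemoteAutoencoderKLWanTests: (batch, channels, frames, height, width).
latent = torch.randn(1, 16, 3, 40, 64, dtype=torch.float16)

# Inferred 8x spatial upscale reproduces the test's out_hw = (320, 512).
spatial_scale = 8  # assumption, inferred from the test constants
out_hw = (latent.shape[-2] * spatial_scale, latent.shape[-1] * spatial_scale)
assert out_hw == (320, 512)
```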
