
Commit 3f69f92 ("comments")
1 parent 5302645

4 files changed: +126, -189 lines


docs/source/en/_toctree.yml

Lines changed: 4 additions & 4 deletions
@@ -112,10 +112,6 @@
   - local: using-diffusers/marigold_usage
     title: Marigold Computer Vision
   title: Specific pipeline examples
-- sections:
-  - local: hybrid_inference/overview
-    title: Overview
-  title: Hybrid Inference
 - sections:
   - local: training/overview
     title: Overview
@@ -632,3 +628,7 @@
       title: Video Processor
     title: Internal classes
   title: API
+- sections:
+  - local: hybrid_inference/overview
+    title: Overview
+  title: Hybrid Inference

docs/source/en/hybrid_inference/overview.md

Lines changed: 5 additions & 0 deletions
@@ -14,6 +14,9 @@ specific language governing permissions and limitations under the License.
 
 **Empowering local AI builders with Hybrid Inference**
 
+> [!TIP]
+> [Hybrid Inference](https://huggingface.co/blog/remote_vae) is an experimental feature.
+
 ---
 
 ## Why use Hybrid Inference?
@@ -24,6 +27,8 @@ Hybrid Inference offers a fast and simple way to offload local generation requirements
 * **VAE Encode (coming soon):** Encode images to latents for generation or training.
 * **Text Encoders (coming soon):** Compute text embeddings for prompts without compromising quality or slowing down your workflow.
 
+Feedback can be provided [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
+
 ---
 
 ## Key Benefits
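For orientation, a minimal sketch of what the remote VAE Decode flow looks like from the caller's side, using the `remote_decode` helper touched in this commit. The endpoint URL and latent shape are placeholders, and the SD v1 scaling factor is the value listed in the `remote_decode` docstring further down:

```python
import torch

from diffusers.utils.remote_utils import remote_decode

# Placeholder latents; in practice these come from a locally run denoising loop.
latents = torch.randn(1, 4, 64, 64, dtype=torch.float16)

image = remote_decode(
    endpoint="https://<your-vae-decode-endpoint>/",  # hypothetical endpoint URL
    tensor=latents,
    scaling_factor=0.18215,  # SD v1 value from the remote_decode docstring
    output_type="pil",
    return_type="pil",
)
image.save("decoded.png")
```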

src/diffusers/utils/remote_utils.py

Lines changed: 14 additions & 4 deletions
@@ -13,8 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# TODO: `imghdr` is deprecated in Python 3.13 🙄
-import imghdr
 import io
 import json
 from typing import List, Literal, Optional, Union, cast
@@ -45,6 +43,18 @@
 from PIL import Image
 
 
+def detect_image_type(data: bytes) -> str:
+    if data.startswith(b"\xff\xd8"):
+        return "jpeg"
+    elif data.startswith(b"\x89PNG\r\n\x1a\n"):
+        return "png"
+    elif data.startswith(b"GIF87a") or data.startswith(b"GIF89a"):
+        return "gif"
+    elif data.startswith(b"BM"):
+        return "bmp"
+    return "unknown"
+
+
 def check_inputs(
     endpoint: str,
     tensor: "torch.Tensor",
@@ -117,7 +127,7 @@ def postprocess(
         )
     elif output_type == "pil" and return_type == "pil" and processor is None:
         output = Image.open(io.BytesIO(response.content)).convert("RGB")
-        detected_format = imghdr.what(None, h=response.content)
+        detected_format = detect_image_type(response.content)
         output.format = detected_format
     elif output_type == "pil" and processor is not None:
         if return_type == "pil":
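Since `detect_image_type` is a plain magic-byte check with no dependencies, it can be sanity-checked in isolation. A small sketch (the function body is copied from the hunk above so the snippet is self-contained):

```python
def detect_image_type(data: bytes) -> str:
    # Magic-byte sniffing; stands in for the deprecated `imghdr` module.
    if data.startswith(b"\xff\xd8"):
        return "jpeg"
    elif data.startswith(b"\x89PNG\r\n\x1a\n"):
        return "png"
    elif data.startswith(b"GIF87a") or data.startswith(b"GIF89a"):
        return "gif"
    elif data.startswith(b"BM"):
        return "bmp"
    return "unknown"


assert detect_image_type(b"\xff\xd8\xff\xe0" + b"\x00" * 16) == "jpeg"
assert detect_image_type(b"\x89PNG\r\n\x1a\n" + b"\x00" * 16) == "png"
assert detect_image_type(b"GIF89a" + b"\x00" * 16) == "gif"
assert detect_image_type(b"BM" + b"\x00" * 16) == "bmp"
assert detect_image_type(b"plain text, not an image") == "unknown"
```

Unlike `imghdr.what`, which returns `None` for unrecognized data, the helper falls back to the string `"unknown"`, so `output.format` is always set to a string in the `postprocess` branch above.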
@@ -207,7 +217,7 @@ def remote_decode(
             / self.vae.config.scaling_factor` is applied remotely. If `False`, input must be passed with scaling
             applied.
         scaling_factor (`float`, *optional*):
-            Scaling is applied when passed e.g. `latents / self.vae.config.scaling_factor`.
+            Scaling is applied when passed e.g. [`latents / self.vae.config.scaling_factor`](https://github.com/huggingface/diffusers/blob/7007febae5cff000d4df9059d9cf35133e8b2ca9/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L1083C37-L1083C77).
             - SD v1: 0.18215
             - SD XL: 0.13025
             - Flux: 0.3611
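The docstring edit above only adds a link, but it documents behavior worth keeping in mind: when `scaling_factor` (and, for models that use one, `shift_factor`) is passed, the latent un-scaling happens remotely; otherwise the caller is expected to send latents with scaling already applied. A hedged sketch of what that un-scaling amounts to, using the factor values listed in the docstring (the Flux line mirrors how the diffusers pipelines un-scale latents before VAE decoding):

```python
import torch

latents_sd = torch.randn(1, 4, 128, 128)     # placeholder SD-style latents
latents_flux = torch.randn(1, 16, 128, 128)  # placeholder Flux-style latents

# SD v1 / SD XL style: a single scaling factor.
unscaled_sd = latents_sd / 0.18215

# Flux style: scaling factor plus shift factor.
unscaled_flux = latents_flux / 0.3611 + 0.1159
```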

tests/remote/test_remote_decode.py

Lines changed: 103 additions & 181 deletions
@@ -200,37 +200,7 @@ def test_output_tensor_type_base64_deprecation(self):
         )
 
 
-class RemoteAutoencoderKLHunyuanVideoMixin:
-    shape: Tuple[int, ...] = None
-    out_hw: Tuple[int, int] = None
-    endpoint: str = None
-    dtype: torch.dtype = None
-    scaling_factor: float = None
-    shift_factor: float = None
-    processor_cls: Union[VaeImageProcessor, VideoProcessor] = None
-    output_pil_slice: torch.Tensor = None
-    output_pt_slice: torch.Tensor = None
-    partial_postprocess_return_pt_slice: torch.Tensor = None
-    return_pt_slice: torch.Tensor = None
-    width: int = None
-    height: int = None
-
-    def get_dummy_inputs(self):
-        inputs = {
-            "endpoint": self.endpoint,
-            "tensor": torch.randn(
-                self.shape,
-                device=torch_device,
-                dtype=self.dtype,
-                generator=torch.Generator(torch_device).manual_seed(13),
-            ),
-            "scaling_factor": self.scaling_factor,
-            "shift_factor": self.shift_factor,
-            "height": self.height,
-            "width": self.width,
-        }
-        return inputs
-
+class RemoteAutoencoderKLHunyuanVideoMixin(RemoteAutoencoderKLMixin):
     def test_no_scaling(self):
         inputs = self.get_dummy_inputs()
         if inputs["scaling_factor"] is not None:
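The net effect of this hunk: `RemoteAutoencoderKLHunyuanVideoMixin` no longer re-declares the configuration attributes and `get_dummy_inputs` that `RemoteAutoencoderKLMixin` already provides; it inherits them and keeps only the video-specific tests. A trimmed sketch of the resulting structure (class names and values are taken from this file; bodies are elided or simplified for illustration):

```python
import unittest

import torch


class RemoteAutoencoderKLMixin:
    # Shared configuration, overridden by each concrete test class.
    shape = None
    endpoint = None
    dtype = None
    scaling_factor = None
    shift_factor = None
    height = None
    width = None

    def get_dummy_inputs(self):
        # Single source of truth for the request payload used by every test.
        return {
            "endpoint": self.endpoint,
            "tensor": torch.randn(self.shape, dtype=self.dtype),
            "scaling_factor": self.scaling_factor,
            "shift_factor": self.shift_factor,
            "height": self.height,
            "width": self.width,
        }


class RemoteAutoencoderKLHunyuanVideoMixin(RemoteAutoencoderKLMixin):
    # Only video-specific tests live here (e.g. test_output_type_mp4);
    # the shared decode and deprecation tests come from the parent mixin.
    pass


class RemoteAutoencoderKLHunyuanVideoTests(RemoteAutoencoderKLHunyuanVideoMixin, unittest.TestCase):
    shape = (1, 16, 3, 40, 64)
    dtype = torch.float16
    scaling_factor = 0.476986
    endpoint = "https://<hunyuan-video-vae-endpoint>/"  # placeholder
```

Concrete test classes keep supplying only their attributes, so the deduplication does not change what each test exercises.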
@@ -354,59 +324,11 @@ def test_output_type_pt_return_type_pt(self):
             f"{output_slice}",
         )
 
-    def test_output_type_pt_partial_postprocess_return_type_pt(self):
-        inputs = self.get_dummy_inputs()
-        output = remote_decode(output_type="pt", partial_postprocess=True, return_type="pt", **inputs)
-        self.assertTrue(isinstance(output, torch.Tensor), f"Expected `torch.Tensor` output, got {type(output)}")
-        self.assertEqual(
-            output.shape[1], self.out_hw[0], f"Expected image height {self.out_hw[0]}, got {output.shape[1]}"
-        )
-        self.assertEqual(
-            output.shape[2], self.out_hw[1], f"Expected image width {self.out_hw[0]}, got {output.shape[2]}"
-        )
-        output_slice = output[0, -3:, -3:, 0].flatten().cpu()
-        self.assertTrue(
-            torch_all_close(output_slice, self.partial_postprocess_return_pt_slice.to(output_slice.dtype), rtol=1e-2),
-            f"{output_slice}",
-        )
-
     def test_output_type_mp4(self):
         inputs = self.get_dummy_inputs()
         output = remote_decode(output_type="mp4", return_type="mp4", **inputs)
         self.assertTrue(isinstance(output, bytes), f"Expected `bytes` output, got {type(output)}")
 
-    def test_do_scaling_deprecation(self):
-        inputs = self.get_dummy_inputs()
-        inputs.pop("scaling_factor", None)
-        inputs.pop("shift_factor", None)
-        with self.assertWarns(FutureWarning) as warning:
-            _ = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
-        self.assertEqual(
-            str(warning.warnings[0].message),
-            "`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.",
-            str(warning.warnings[0].message),
-        )
-
-    def test_input_tensor_type_base64_deprecation(self):
-        inputs = self.get_dummy_inputs()
-        with self.assertWarns(FutureWarning) as warning:
-            _ = remote_decode(output_type="pt", input_tensor_type="base64", partial_postprocess=True, **inputs)
-        self.assertEqual(
-            str(warning.warnings[0].message),
-            "input_tensor_type='base64' is deprecated. Using `binary`.",
-            str(warning.warnings[0].message),
-        )
-
-    def test_output_tensor_type_base64_deprecation(self):
-        inputs = self.get_dummy_inputs()
-        with self.assertWarns(FutureWarning) as warning:
-            _ = remote_decode(output_type="pt", output_tensor_type="base64", partial_postprocess=True, **inputs)
-        self.assertEqual(
-            str(warning.warnings[0].message),
-            "output_tensor_type='base64' is deprecated. Using `binary`.",
-            str(warning.warnings[0].message),
-        )
-
 
 class RemoteAutoencoderKLSDv1Tests(
     RemoteAutoencoderKLMixin,
@@ -432,105 +354,105 @@ class RemoteAutoencoderKLSDv1Tests(
     return_pt_slice = torch.tensor([-0.2177, 0.0217, -0.2258, 0.0412, -0.1687, -0.1232, -0.2416, -0.2130, -0.0543])
 
 
-class RemoteAutoencoderKLSDXLTests(
-    RemoteAutoencoderKLMixin,
-    unittest.TestCase,
-):
-    shape = (
-        1,
-        4,
-        128,
-        128,
-    )
-    out_hw = (
-        1024,
-        1024,
-    )
-    endpoint = "https://fagf07t3bwf0615i.us-east-1.aws.endpoints.huggingface.cloud/"
-    dtype = torch.float16
-    scaling_factor = 0.13025
-    shift_factor = None
-    processor_cls = VaeImageProcessor
-    output_pt_slice = torch.tensor([104, 52, 23, 114, 61, 35, 108, 87, 38], dtype=torch.uint8)
-    partial_postprocess_return_pt_slice = torch.tensor([77, 86, 89, 49, 60, 75, 52, 65, 78], dtype=torch.uint8)
-    return_pt_slice = torch.tensor([-0.3945, -0.3289, -0.2993, -0.6177, -0.5259, -0.4119, -0.5898, -0.4863, -0.3845])
-
-
-class RemoteAutoencoderKLFluxTests(
-    RemoteAutoencoderKLMixin,
-    unittest.TestCase,
-):
-    shape = (
-        1,
-        16,
-        128,
-        128,
-    )
-    out_hw = (
-        1024,
-        1024,
-    )
-    endpoint = "https://fnohtuwsskxgxsnn.us-east-1.aws.endpoints.huggingface.cloud/"
-    dtype = torch.bfloat16
-    scaling_factor = 0.3611
-    shift_factor = 0.1159
-    processor_cls = VaeImageProcessor
-    output_pt_slice = torch.tensor([110, 72, 91, 62, 35, 52, 69, 55, 69], dtype=torch.uint8)
-    partial_postprocess_return_pt_slice = torch.tensor(
-        [202, 203, 203, 197, 195, 193, 189, 188, 178], dtype=torch.uint8
-    )
-    return_pt_slice = torch.tensor([0.5820, 0.5962, 0.5898, 0.5439, 0.5327, 0.5112, 0.4797, 0.4773, 0.3984])
-
-
-class RemoteAutoencoderKLFluxPackedTests(
-    RemoteAutoencoderKLMixin,
-    unittest.TestCase,
-):
-    shape = (
-        1,
-        4096,
-        64,
-    )
-    out_hw = (
-        1024,
-        1024,
-    )
-    height = 1024
-    width = 1024
-    endpoint = "https://fnohtuwsskxgxsnn.us-east-1.aws.endpoints.huggingface.cloud/"
-    dtype = torch.bfloat16
-    scaling_factor = 0.3611
-    shift_factor = 0.1159
-    processor_cls = VaeImageProcessor
-    # slices are different due to randn on different shape. we can pack the latent instead if we want the same
-    output_pt_slice = torch.tensor([96, 116, 157, 45, 67, 104, 34, 56, 89], dtype=torch.uint8)
-    partial_postprocess_return_pt_slice = torch.tensor(
-        [168, 212, 202, 155, 191, 185, 150, 180, 168], dtype=torch.uint8
-    )
-    return_pt_slice = torch.tensor([0.3198, 0.6631, 0.5864, 0.2131, 0.4944, 0.4482, 0.1776, 0.4153, 0.3176])
-
-
-class RemoteAutoencoderKLHunyuanVideoTests(
-    RemoteAutoencoderKLHunyuanVideoMixin,
-    unittest.TestCase,
-):
-    shape = (
-        1,
-        16,
-        3,
-        40,
-        64,
-    )
-    out_hw = (
-        320,
-        512,
-    )
-    endpoint = "https://lsx2injm3ts8wbvv.us-east-1.aws.endpoints.huggingface.cloud/"
-    dtype = torch.float16
-    scaling_factor = 0.476986
-    processor_cls = VideoProcessor
-    output_pt_slice = torch.tensor([112, 92, 85, 112, 93, 85, 112, 94, 85], dtype=torch.uint8)
-    partial_postprocess_return_pt_slice = torch.tensor(
-        [149, 161, 168, 136, 150, 156, 129, 143, 149], dtype=torch.uint8
-    )
-    return_pt_slice = torch.tensor([0.1656, 0.2661, 0.3157, 0.0693, 0.1755, 0.2252, 0.0127, 0.1221, 0.1708])
+# class RemoteAutoencoderKLSDXLTests(
+#     RemoteAutoencoderKLMixin,
+#     unittest.TestCase,
+# ):
+#     shape = (
+#         1,
+#         4,
+#         128,
+#         128,
+#     )
+#     out_hw = (
+#         1024,
+#         1024,
+#     )
+#     endpoint = "https://fagf07t3bwf0615i.us-east-1.aws.endpoints.huggingface.cloud/"
+#     dtype = torch.float16
+#     scaling_factor = 0.13025
+#     shift_factor = None
+#     processor_cls = VaeImageProcessor
+#     output_pt_slice = torch.tensor([104, 52, 23, 114, 61, 35, 108, 87, 38], dtype=torch.uint8)
+#     partial_postprocess_return_pt_slice = torch.tensor([77, 86, 89, 49, 60, 75, 52, 65, 78], dtype=torch.uint8)
+#     return_pt_slice = torch.tensor([-0.3945, -0.3289, -0.2993, -0.6177, -0.5259, -0.4119, -0.5898, -0.4863, -0.3845])
+
+
+# class RemoteAutoencoderKLFluxTests(
+#     RemoteAutoencoderKLMixin,
+#     unittest.TestCase,
+# ):
+#     shape = (
+#         1,
+#         16,
+#         128,
+#         128,
+#     )
+#     out_hw = (
+#         1024,
+#         1024,
+#     )
+#     endpoint = "https://fnohtuwsskxgxsnn.us-east-1.aws.endpoints.huggingface.cloud/"
+#     dtype = torch.bfloat16
+#     scaling_factor = 0.3611
+#     shift_factor = 0.1159
+#     processor_cls = VaeImageProcessor
+#     output_pt_slice = torch.tensor([110, 72, 91, 62, 35, 52, 69, 55, 69], dtype=torch.uint8)
+#     partial_postprocess_return_pt_slice = torch.tensor(
+#         [202, 203, 203, 197, 195, 193, 189, 188, 178], dtype=torch.uint8
+#     )
+#     return_pt_slice = torch.tensor([0.5820, 0.5962, 0.5898, 0.5439, 0.5327, 0.5112, 0.4797, 0.4773, 0.3984])
+
+
+# class RemoteAutoencoderKLFluxPackedTests(
+#     RemoteAutoencoderKLMixin,
+#     unittest.TestCase,
+# ):
+#     shape = (
+#         1,
+#         4096,
+#         64,
+#     )
+#     out_hw = (
+#         1024,
+#         1024,
+#     )
+#     height = 1024
+#     width = 1024
+#     endpoint = "https://fnohtuwsskxgxsnn.us-east-1.aws.endpoints.huggingface.cloud/"
+#     dtype = torch.bfloat16
+#     scaling_factor = 0.3611
+#     shift_factor = 0.1159
+#     processor_cls = VaeImageProcessor
+#     # slices are different due to randn on different shape. we can pack the latent instead if we want the same
+#     output_pt_slice = torch.tensor([96, 116, 157, 45, 67, 104, 34, 56, 89], dtype=torch.uint8)
+#     partial_postprocess_return_pt_slice = torch.tensor(
+#         [168, 212, 202, 155, 191, 185, 150, 180, 168], dtype=torch.uint8
+#     )
+#     return_pt_slice = torch.tensor([0.3198, 0.6631, 0.5864, 0.2131, 0.4944, 0.4482, 0.1776, 0.4153, 0.3176])
+
+
+# class RemoteAutoencoderKLHunyuanVideoTests(
+#     RemoteAutoencoderKLHunyuanVideoMixin,
+#     unittest.TestCase,
+# ):
+#     shape = (
+#         1,
+#         16,
+#         3,
+#         40,
+#         64,
+#     )
+#     out_hw = (
+#         320,
+#         512,
+#     )
+#     endpoint = "https://lsx2injm3ts8wbvv.us-east-1.aws.endpoints.huggingface.cloud/"
+#     dtype = torch.float16
+#     scaling_factor = 0.476986
+#     processor_cls = VideoProcessor
+#     output_pt_slice = torch.tensor([112, 92, 85, 112, 93, 85, 112, 94, 85], dtype=torch.uint8)
+#     partial_postprocess_return_pt_slice = torch.tensor(
+#         [149, 161, 168, 136, 150, 156, 129, 143, 149], dtype=torch.uint8
+#     )
+#     return_pt_slice = torch.tensor([0.1656, 0.2661, 0.3157, 0.0693, 0.1755, 0.2252, 0.0127, 0.1221, 0.1708])
