
Commit 47498fb

Add remote_decode to remote_utils
1 parent cc7b5b8 commit 47498fb

2 files changed: +120 -0 lines changed

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -116,6 +116,7 @@
     unscale_lora_layers,
 )
 from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
+from .remote_utils import remote_decode
 from .state_dict_utils import (
     convert_all_state_dict_to_peft,
     convert_state_dict_to_diffusers,
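
With this re-export in place, the helper should be importable straight from the package-level utils module rather than the private submodule, e.g.:

from diffusers.utils import remote_decode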
src/diffusers/utils/remote_utils.py

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
import base64
import io
import json
from typing import List, Literal, Optional, Union, cast

import requests
import torch
from PIL import Image
from safetensors.torch import _tobytes

from ..image_processor import VaeImageProcessor
from ..video_processor import VideoProcessor


# Maps the dtype string sent by the endpoint back to the corresponding torch dtype.
DTYPE_MAP = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
    "uint8": torch.uint8,
}


def remote_decode(
    endpoint: str,
    tensor: torch.Tensor,
    processor: Optional[Union[VaeImageProcessor, VideoProcessor]] = None,
    do_scaling: bool = True,
    output_type: Literal["mp4", "pil", "pt"] = "pil",
    image_format: Literal["png", "jpg"] = "jpg",
    partial_postprocess: bool = False,
    input_tensor_type: Literal["base64", "binary"] = "base64",
    output_tensor_type: Literal["base64", "binary"] = "base64",
    height: Optional[int] = None,
    width: Optional[int] = None,
) -> Union[Image.Image, List[Image.Image], bytes, torch.Tensor]:
    if tensor.ndim == 3 and height is None and width is None:
        raise ValueError("`height` and `width` required for packed latents.")
    if output_type == "pt" and partial_postprocess is False and processor is None:
        raise ValueError("`processor` is required with `output_type='pt'` and `partial_postprocess=False`.")
    headers = {}
    # Latent metadata travels as query parameters alongside the raw tensor payload.
    parameters = {
        "do_scaling": do_scaling,
        "output_type": output_type,
        "partial_postprocess": partial_postprocess,
        "shape": list(tensor.shape),
        "dtype": str(tensor.dtype).split(".")[-1],
    }
    if height is not None and width is not None:
        parameters["height"] = height
        parameters["width"] = width
    tensor_data = _tobytes(tensor, "tensor")
    # The request and response wire formats are negotiated via Content-Type / Accept headers.
    if input_tensor_type == "base64":
        headers["Content-Type"] = "tensor/base64"
    elif input_tensor_type == "binary":
        headers["Content-Type"] = "tensor/binary"
    if output_type == "pil" and image_format == "jpg" and processor is None:
        headers["Accept"] = "image/jpeg"
    elif output_type == "pil" and image_format == "png" and processor is None:
        headers["Accept"] = "image/png"
    elif (output_tensor_type == "base64" and output_type == "pt") or (
        output_tensor_type == "base64" and output_type == "pil" and processor is not None
    ):
        headers["Accept"] = "tensor/base64"
    elif (output_tensor_type == "binary" and output_type == "pt") or (
        output_tensor_type == "binary" and output_type == "pil" and processor is not None
    ):
        headers["Accept"] = "tensor/binary"
    elif output_type == "mp4":
        headers["Accept"] = "text/plain"
    if input_tensor_type == "base64":
        kwargs = {"json": {"inputs": base64.b64encode(tensor_data).decode("utf-8")}}
    elif input_tensor_type == "binary":
        kwargs = {"data": tensor_data}
    response = requests.post(endpoint, params=parameters, **kwargs, headers=headers)
    if not response.ok:
        raise RuntimeError(response.json())
    if output_type == "pt" or (output_type == "pil" and processor is not None):
        # Rebuild the returned tensor from raw bytes plus the shape/dtype metadata.
        if output_tensor_type == "base64":
            content = response.json()
            output_tensor = base64.b64decode(content["inputs"])
            parameters = content["parameters"]
            shape = parameters["shape"]
            dtype = parameters["dtype"]
        elif output_tensor_type == "binary":
            output_tensor = response.content
            parameters = response.headers
            shape = json.loads(parameters["shape"])
            dtype = parameters["dtype"]
        torch_dtype = DTYPE_MAP[dtype]
        output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
    if output_type == "pt":
        if partial_postprocess:
            output = [Image.fromarray(image.numpy()) for image in output_tensor]
            if len(output) == 1:
                output = output[0]
        else:
            if processor is None:
                output = output_tensor
            else:
                if isinstance(processor, VideoProcessor):
                    output = cast(
                        List[Image.Image],
                        processor.postprocess_video(output_tensor, output_type="pil")[0],
                    )
                else:
                    output = cast(
                        Image.Image,
                        processor.postprocess(output_tensor, output_type="pil")[0],
                    )
    elif output_type == "pil" and processor is None:
        output = Image.open(io.BytesIO(response.content)).convert("RGB")
    elif output_type == "pil" and processor is not None:
        output = [
            Image.fromarray(image)
            for image in (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255).round().astype("uint8")
        ]
    elif output_type == "mp4":
        output = response.content
    return output
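
A minimal usage sketch of the new helper (the endpoint URL and latent shape below are illustrative placeholders, not part of this commit; a running remote VAE-decode endpoint is assumed):

import torch

from diffusers.utils import remote_decode

# Hypothetical endpoint serving a remote VAE decoder.
ENDPOINT = "https://example.com/decode"

# Stand-in for latents produced by a diffusion pipeline run with output_type="latent".
latents = torch.randn(1, 4, 64, 64, dtype=torch.float16)

# With processor=None and output_type="pil", the endpoint is asked to return an
# encoded image directly (Accept: image/png), which is opened as a PIL image.
image = remote_decode(
    endpoint=ENDPOINT,
    tensor=latents,
    output_type="pil",
    image_format="png",
)
image.save("decoded.png")

For video endpoints, output_type="mp4" returns the raw encoded bytes, which can be written directly to a file.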

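For reference, the "tensor/base64" transport that remote_decode relies on amounts to the round trip below (a standalone sketch of the serialization logic, mirroring how the function encodes the request payload and decodes the server response):

import base64

import torch
from safetensors.torch import _tobytes

t = torch.randn(1, 4, 8, 8, dtype=torch.float16)

# Encode: raw tensor bytes -> base64 string, with shape/dtype sent alongside as metadata.
payload = base64.b64encode(_tobytes(t, "tensor")).decode("utf-8")
meta = {"shape": list(t.shape), "dtype": str(t.dtype).split(".")[-1]}

# Decode: rebuild the tensor from the bytes plus the shape/dtype metadata
# (remote_decode resolves the dtype string through DTYPE_MAP; float16 is hardcoded here).
raw = base64.b64decode(payload)
restored = torch.frombuffer(bytearray(raw), dtype=torch.float16).reshape(meta["shape"])
assert torch.equal(t, restored)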