Skip to content

Commit f80ef6d

Browse files
committed
changes
1 parent 2a2157e commit f80ef6d

File tree

1 file changed

+59
-18
lines changed

1 file changed

+59
-18
lines changed

src/diffusers/utils/remote_utils.py

Lines changed: 59 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,31 @@
2828
from PIL import Image
2929

3030

31+
def check_inputs(
    endpoint: str,
    tensor: "torch.Tensor",
    processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
    do_scaling: bool = True,
    output_type: Literal["mp4", "pil", "pt"] = "pil",
    return_type: Literal["mp4", "pil", "pt"] = "pil",
    image_format: Literal["png", "jpg"] = "jpg",
    partial_postprocess: bool = False,
    input_tensor_type: Literal["base64", "binary"] = "base64",
    output_tensor_type: Literal["base64", "binary"] = "base64",
    height: Optional[int] = None,
    width: Optional[int] = None,
):
    """Validate the argument combination passed to `remote_decode`.

    The signature mirrors `remote_decode` so the arguments can be forwarded
    positionally. Only `tensor`, `processor`, `output_type`, `return_type`,
    `partial_postprocess`, `height` and `width` are inspected here; the
    remaining parameters are accepted but not checked.

    Raises:
        ValueError: If `tensor` is 3-dimensional (treated as packed latents)
            and `height` or `width` is missing.
        ValueError: If `output_type="pt"`, `return_type="pil"` and
            `partial_postprocess` is falsy, but `processor` is not a
            `VaeImageProcessor` or `VideoProcessor` instance.
    """
    # Both dimensions are needed for packed latents, so reject the call when
    # either one is missing, not only when both are.
    if tensor.ndim == 3 and (height is None or width is None):
        raise ValueError("`height` and `width` required for packed latents.")
    # A raw tensor response that must be returned as PIL needs a real
    # processor instance to do the conversion locally.
    if (
        output_type == "pt"
        and return_type == "pil"
        and not partial_postprocess
        and not isinstance(processor, (VaeImageProcessor, VideoProcessor))
    ):
        raise ValueError("`processor` is required.")
54+
55+
3156
def remote_decode(
3257
endpoint: str,
3358
tensor: "torch.Tensor",
@@ -63,7 +88,7 @@ def remote_decode(
6388
Requires `processor` as a flag (any non-`None` value will work).
6489
`"pt"`: Supported by image and video models. Endpoint returns `torch.Tensor`.
6590
With `partial_postprocess=True` the tensor is postprocessed `uint8` image tensor.
66-
91+
6792
Recommendations:
6893
`"pt"` with `partial_postprocess=True` is the smallest transfer for full quality.
6994
`"pt"` with `partial_postprocess=False` is the most compatible with third party code.
@@ -85,28 +110,38 @@ def remote_decode(
85110
86111
image_format (`"png"` or `"jpg"`, default `jpg`):
87112
Used with `output_type="pil"`. Endpoint returns `jpg` or `png`.
88-
113+
89114
partial_postprocess (`bool`, default `False`):
90115
Used with `output_type="pt"`.
91116
`partial_postprocess=False` tensor is `float16` or `bfloat16`, without denormalization.
92117
`partial_postprocess=True` tensor is `uint8`, denormalized.
93-
118+
94119
input_tensor_type (`"base64"` or `"binary"`, default `"base64"`):
95120
With `"base64"` `tensor` is sent to endpoint base64 encoded. `"binary"` reduces overhead and transfer.
96121
97122
output_tensor_type (`"base64"` or `"binary"`, default `"base64"`):
98123
With `"base64"` `tensor` returned by endpoint is base64 encoded. `"binary"` reduces overhead and transfer.
99-
124+
100125
height (`int`, **optional**):
101126
Required for `"packed"` latents.
102127
103128
width (`int`, **optional**):
104129
Required for `"packed"` latents.
105130
"""
106-
if tensor.ndim == 3 and height is None and width is None:
107-
raise ValueError("`height` and `width` required for packed latents.")
108-
if output_type == "pt" and partial_postprocess is False and processor is None:
109-
raise ValueError("`processor` is required with `output_type='pt' and `partial_postprocess=False`.")
131+
check_inputs(
132+
endpoint,
133+
tensor,
134+
processor,
135+
do_scaling,
136+
output_type,
137+
return_type,
138+
image_format,
139+
partial_postprocess,
140+
input_tensor_type,
141+
output_tensor_type,
142+
height,
143+
width,
144+
)
110145
headers = {}
111146
parameters = {
112147
"do_scaling": do_scaling,
@@ -160,11 +195,14 @@ def remote_decode(
160195
output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
161196
if output_type == "pt":
162197
if partial_postprocess:
163-
output = [Image.fromarray(image.numpy()) for image in output_tensor]
164-
if len(output) == 1:
165-
output = output[0]
198+
if return_type == "pil":
199+
output = [Image.fromarray(image.numpy()) for image in output_tensor]
200+
if len(output) == 1:
201+
output = output[0]
202+
elif return_type == "pt":
203+
output = output_tensor
166204
else:
167-
if processor is None:
205+
if processor is None or return_type == "pt":
168206
output = output_tensor
169207
else:
170208
if isinstance(processor, VideoProcessor):
@@ -177,13 +215,16 @@ def remote_decode(
177215
Image.Image,
178216
processor.postprocess(output_tensor, output_type="pil")[0],
179217
)
180-
elif output_type == "pil" and processor is None:
218+
elif output_type == "pil" and return_type == "pil" and processor is None:
181219
output = Image.open(io.BytesIO(response.content)).convert("RGB")
182220
elif output_type == "pil" and processor is not None:
183-
output = [
184-
Image.fromarray(image)
185-
for image in (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255).round().astype("uint8")
186-
]
187-
elif output_type == "mp4":
221+
if return_type == "pil":
222+
output = [
223+
Image.fromarray(image)
224+
for image in (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255).round().astype("uint8")
225+
]
226+
elif return_type == "pt":
227+
output = output_tensor
228+
elif output_type == "mp4" and return_type == "mp4":
188229
output = response.content
189230
return output

0 commit comments

Comments
 (0)