Skip to content

Commit e55139b

Browse files
committed
revert, add new options
1 parent 1978a8a commit e55139b

File tree

1 file changed

+115
-111
lines changed

1 file changed

+115
-111
lines changed

src/diffusers/utils/remote_utils.py

Lines changed: 115 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from typing import List, Literal, Optional, Union, cast
55

66
import requests
7-
from PIL import Image
87

8+
from .deprecation_utils import deprecate
99
from .import_utils import is_safetensors_available, is_torch_available
1010

1111

@@ -16,7 +16,8 @@
1616
from ..video_processor import VideoProcessor
1717

1818
if is_safetensors_available():
19-
import safetensors
19+
import safetensors.torch
20+
2021
DTYPE_MAP = {
2122
"float16": torch.float16,
2223
"float32": torch.float32,
@@ -25,11 +26,16 @@
2526
}
2627

2728

29+
from PIL import Image
30+
31+
2832
def check_inputs(
2933
endpoint: str,
3034
tensor: "torch.Tensor",
3135
processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
3236
do_scaling: bool = True,
37+
scaling_factor: Optional[float] = None,
38+
shift_factor: Optional[float] = None,
3339
output_type: Literal["mp4", "pil", "pt"] = "pil",
3440
return_type: Literal["mp4", "pil", "pt"] = "pil",
3541
image_format: Literal["png", "jpg"] = "jpg",
@@ -48,81 +54,22 @@ def check_inputs(
4854
and not isinstance(processor, (VaeImageProcessor, VideoProcessor))
4955
):
5056
raise ValueError("`processor` is required.")
51-
52-
53-
def _prepare_headers(
    input_tensor_type: Literal["base64", "binary"],
    output_type: Literal["mp4", "pil", "pt"],
    image_format: Literal["png", "jpg"],
    processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]],
    output_tensor_type: Literal["base64", "binary"],
) -> dict:
    """Build the HTTP headers for a remote decode request.

    Args:
        input_tensor_type: Encoding of the request body. `"base64"` selects
            `tensor/base64`; any other value selects `tensor/binary`.
        output_type: Output type requested from the endpoint.
        image_format: Image format requested when the endpoint returns an
            image. `"jpg"` selects `image/jpeg`; any other value selects
            `image/png`.
        processor: When `None` and `output_type == "pil"`, an encoded image
            is requested; otherwise a tensor is requested.
        output_tensor_type: Encoding of a tensor response. `"base64"` selects
            `tensor/base64`; any other value selects `tensor/binary`.

    Returns:
        A dict with `"Content-Type"` always set, and `"Accept"` set when
        `output_type` is one of `"pil"`, `"pt"` or `"mp4"`.
    """
    headers = {
        "Content-Type": "tensor/base64" if input_tensor_type == "base64" else "tensor/binary",
    }
    # Both tensor-returning cases ("pt", and "pil" with a local processor)
    # ask for the same media type, so compute it once.
    tensor_accept = "tensor/base64" if output_tensor_type == "base64" else "tensor/binary"

    if output_type == "pil" and processor is None:
        # No local processor: the endpoint returns an encoded image directly.
        headers["Accept"] = "image/jpeg" if image_format == "jpg" else "image/png"
    elif output_type in ("pil", "pt"):
        headers["Accept"] = tensor_accept
    elif output_type == "mp4":
        headers["Accept"] = "text/plain"
    return headers
73-
74-
75-
def _prepare_parameters(
    tensor: "torch.Tensor",
    do_scaling: bool,
    output_type: Literal["mp4", "pil", "pt"],
    partial_postprocess: bool,
    height: Optional[int],
    width: Optional[int],
) -> dict:
    """Build the query parameters describing the tensor and requested output.

    Args:
        tensor: Tensor being sent; only its `shape` and `dtype` are read.
        do_scaling: Forwarded unchanged under the `"do_scaling"` key.
        output_type: Forwarded unchanged under the `"output_type"` key.
        partial_postprocess: Forwarded unchanged under the
            `"partial_postprocess"` key.
        height: Included only when `width` is also given.
        width: Included only when `height` is also given.

    Returns:
        A dict with `"do_scaling"`, `"output_type"`, `"partial_postprocess"`,
        `"shape"` (list of ints) and `"dtype"` (the dtype name without its
        module prefix, e.g. `"float16"`), plus `"height"` and `"width"` when
        both are not `None`.
    """
    params = {
        "do_scaling": do_scaling,
        "output_type": output_type,
        "partial_postprocess": partial_postprocess,
        "shape": list(tensor.shape),
        # str(torch.float16) == "torch.float16"; keep only the trailing name.
        "dtype": str(tensor.dtype).split(".")[-1],
    }
    # Height and width are only sent as a pair; a lone value is dropped.
    if height is not None and width is not None:
        params["height"] = height
        params["width"] = width
    return params
94-
95-
96-
def _encode_tensor_data(tensor: "torch.Tensor", input_tensor_type: Literal["base64", "binary"]) -> dict:
    """Serialize `tensor` into keyword arguments for `requests.post`.

    Args:
        tensor: Tensor to send.
        input_tensor_type: `"base64"` wraps the bytes in a JSON body; any
            other value sends the raw bytes.

    Returns:
        `{"json": {"inputs": <base64 str>}}` for `"base64"`, otherwise
        `{"data": <bytes>}`.
    """
    # NOTE(review): `_tobytes` is a private safetensors helper and may change
    # between releases; confirm there is no public equivalent before relying on it.
    tensor_data = safetensors.torch._tobytes(tensor, "tensor")
    if input_tensor_type == "base64":
        encoded = base64.b64encode(tensor_data).decode("utf-8")
        return {"json": {"inputs": encoded}}
    return {"data": tensor_data}
101-
102-
103-
def _decode_tensor_response(response: "requests.Response", output_tensor_type: Literal["base64", "binary"]):
    """Extract the raw tensor bytes and their metadata from an endpoint response.

    Args:
        response: The endpoint's HTTP response.
        output_tensor_type: `"base64"` reads a JSON body with `"inputs"`
            (base64 text) and `"parameters"`; any other value reads the raw
            body and takes the metadata from the response headers.

    Returns:
        A `(tensor_bytes, params)` tuple. In the binary case `params` is a
        copy of the headers with `"shape"` JSON-decoded. In the base64 case
        `params` is the `"parameters"` object from the body, returned as is.
    """
    if output_tensor_type == "base64":
        content = response.json()
        tensor_bytes = base64.b64decode(content["inputs"])
        params = content["parameters"]
    else:
        tensor_bytes = response.content
        # Copy so that decoding "shape" does not mutate the response's headers.
        params = response.headers.copy()
        # Header values are strings, so the shape arrives JSON-encoded.
        params["shape"] = json.loads(params["shape"])
    return tensor_bytes, params
113-
114-
115-
def _tensor_to_pil_images(tensor: "torch.Tensor") -> Union[Image.Image, List[Image.Image]]:
    """Convert a batch of image tensors into PIL images.

    Args:
        tensor: Iterated along its first dimension; each item is permuted
            with `(1, 2, 0)` before conversion.
            Assumes `[batch, channels, height, width]` with values in
            `[0, 1]` — TODO confirm against the endpoint's output.

    Returns:
        A single image when the batch holds exactly one item, otherwise a
        list of images.
    """
    images = [
        # `.float()` first: NumPy has no bfloat16, so `.numpy()` would raise
        # for tensors in that dtype.
        Image.fromarray((img.permute(1, 2, 0).cpu().float().numpy() * 255).round().astype("uint8"))
        for img in tensor
    ]
    return images[0] if len(images) == 1 else images
57+
if do_scaling and scaling_factor is None:
58+
deprecate(
59+
"do_scaling",
60+
"1.0.0",
61+
"`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.",
62+
standard_warn=False,
63+
)
11964

12065

12166
def remote_decode(
12267
endpoint: str,
12368
tensor: "torch.Tensor",
12469
processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
12570
do_scaling: bool = True,
71+
scaling_factor: Optional[float] = None,
72+
shift_factor: Optional[float] = None,
12673
output_type: Literal["mp4", "pil", "pt"] = "pil",
12774
return_type: Literal["mp4", "pil", "pt"] = "pil",
12875
image_format: Literal["png", "jpg"] = "jpg",
@@ -141,8 +88,16 @@ def remote_decode(
14188
processor (`VaeImageProcessor` or `VideoProcessor`, *optional*):
14289
Used with `return_type="pt"`, and `return_type="pil"` for Video models.
14390
do_scaling (`bool`, default `True`, *optional*):
144-
When `True` scaling e.g. `latents / self.vae.config.scaling_factor` is applied remotely. If `False`, input
91+
**DEPRECATED**. When `True` scaling e.g. `latents / self.vae.config.scaling_factor` is applied remotely. If `False`, input
14592
must be passed with scaling applied.
93+
scaling_factor (`float`, *optional*):
94+
Scaling is applied when passed e.g. `latents / self.vae.config.scaling_factor`.
95+
SD v1: 0.18215
96+
SD XL: 0.13025
97+
Flux: 0.3611
98+
shift_factor (`float`, *optional*):
99+
Shift is applied when passed e.g. `latents + self.vae.config.shift_factor`.
100+
Flux: 0.1159
146101
output_type (`"mp4"` or `"pil"` or `"pt"`, default `"pil"`):
147102
**Endpoint** output type. Subject to change. Report feedback on preferred type.
148103
@@ -190,12 +145,13 @@ def remote_decode(
190145
width (`int`, *optional*):
191146
Required for `"packed"` latents.
192147
"""
193-
194148
check_inputs(
195149
endpoint,
196150
tensor,
197151
processor,
198152
do_scaling,
153+
scaling_factor,
154+
shift_factor,
199155
output_type,
200156
return_type,
201157
image_format,
@@ -205,48 +161,96 @@ def remote_decode(
205161
height,
206162
width,
207163
)
208-
209-
# Prepare request details.
210-
headers = _prepare_headers(input_tensor_type, output_type, image_format, processor, output_tensor_type)
211-
params = _prepare_parameters(tensor, do_scaling, output_type, partial_postprocess, height, width)
212-
payload = _encode_tensor_data(tensor, input_tensor_type)
213-
214-
response = requests.post(endpoint, params=params, headers=headers, **payload)
164+
headers = {}
165+
parameters = {
166+
"output_type": output_type,
167+
"partial_postprocess": partial_postprocess,
168+
"shape": list(tensor.shape),
169+
"dtype": str(tensor.dtype).split(".")[-1],
170+
}
171+
if do_scaling and scaling_factor is not None:
172+
parameters["scaling_factor"] = scaling_factor
173+
if do_scaling and shift_factor is not None:
174+
parameters["shift_factor"] = shift_factor
175+
if do_scaling and scaling_factor is None:
176+
parameters["do_scaling"] = do_scaling
177+
elif do_scaling and scaling_factor is None and shift_factor is None:
178+
parameters["do_scaling"] = do_scaling
179+
if height is not None and width is not None:
180+
parameters["height"] = height
181+
parameters["width"] = width
182+
tensor_data = safetensors.torch._tobytes(tensor, "tensor")
183+
if input_tensor_type == "base64":
184+
headers["Content-Type"] = "tensor/base64"
185+
elif input_tensor_type == "binary":
186+
headers["Content-Type"] = "tensor/binary"
187+
if output_type == "pil" and image_format == "jpg" and processor is None:
188+
headers["Accept"] = "image/jpeg"
189+
elif output_type == "pil" and image_format == "png" and processor is None:
190+
headers["Accept"] = "image/png"
191+
elif (output_tensor_type == "base64" and output_type == "pt") or (
192+
output_tensor_type == "base64" and output_type == "pil" and processor is not None
193+
):
194+
headers["Accept"] = "tensor/base64"
195+
elif (output_tensor_type == "binary" and output_type == "pt") or (
196+
output_tensor_type == "binary" and output_type == "pil" and processor is not None
197+
):
198+
headers["Accept"] = "tensor/binary"
199+
elif output_type == "mp4":
200+
headers["Accept"] = "text/plain"
201+
if input_tensor_type == "base64":
202+
kwargs = {"json": {"inputs": base64.b64encode(tensor_data).decode("utf-8")}}
203+
elif input_tensor_type == "binary":
204+
kwargs = {"data": tensor_data}
205+
response = requests.post(endpoint, params=parameters, **kwargs, headers=headers)
215206
if not response.ok:
216207
raise RuntimeError(response.json())
217-
218-
# Process responses that return a tensor.
219-
if output_type in ("pt",) or (output_type == "pil" and processor is not None):
220-
tensor_bytes, tensor_params = _decode_tensor_response(response, output_tensor_type)
221-
shape = tensor_params["shape"]
222-
dtype = tensor_params["dtype"]
208+
if output_type == "pt" or (output_type == "pil" and processor is not None):
209+
if output_tensor_type == "base64":
210+
content = response.json()
211+
output_tensor = base64.b64decode(content["inputs"])
212+
parameters = content["parameters"]
213+
shape = parameters["shape"]
214+
dtype = parameters["dtype"]
215+
elif output_tensor_type == "binary":
216+
output_tensor = response.content
217+
parameters = response.headers
218+
shape = json.loads(parameters["shape"])
219+
dtype = parameters["dtype"]
223220
torch_dtype = DTYPE_MAP[dtype]
224-
output_tensor = torch.frombuffer(bytearray(tensor_bytes), dtype=torch_dtype).reshape(shape)
225-
226-
if output_type == "pt":
227-
if partial_postprocess:
228-
if return_type == "pil":
229-
return _tensor_to_pil_images(output_tensor)
230-
return output_tensor
221+
output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
222+
if output_type == "pt":
223+
if partial_postprocess:
224+
if return_type == "pil":
225+
output = [Image.fromarray(image.numpy()) for image in output_tensor]
226+
if len(output) == 1:
227+
output = output[0]
228+
elif return_type == "pt":
229+
output = output_tensor
230+
else:
231+
if processor is None or return_type == "pt":
232+
output = output_tensor
231233
else:
232-
if processor is None or return_type == "pt":
233-
return output_tensor
234234
if isinstance(processor, VideoProcessor):
235-
return cast(List[Image.Image], processor.postprocess_video(output_tensor, output_type="pil")[0])
236-
return cast(Image.Image, processor.postprocess(output_tensor, output_type="pil")[0])
237-
238-
if output_type == "pil" and processor is None and return_type == "pil":
239-
return Image.open(io.BytesIO(response.content)).convert("RGB")
240-
241-
if output_type == "pil" and processor is not None:
242-
tensor_bytes, tensor_params = _decode_tensor_response(response, output_tensor_type)
243-
shape = tensor_params["shape"]
244-
dtype = tensor_params["dtype"]
245-
torch_dtype = DTYPE_MAP[dtype]
246-
output_tensor = torch.frombuffer(bytearray(tensor_bytes), dtype=torch_dtype).reshape(shape)
235+
output = cast(
236+
List[Image.Image],
237+
processor.postprocess_video(output_tensor, output_type="pil")[0],
238+
)
239+
else:
240+
output = cast(
241+
Image.Image,
242+
processor.postprocess(output_tensor, output_type="pil")[0],
243+
)
244+
elif output_type == "pil" and return_type == "pil" and processor is None:
245+
output = Image.open(io.BytesIO(response.content)).convert("RGB")
246+
elif output_type == "pil" and processor is not None:
247247
if return_type == "pil":
248-
return _tensor_to_pil_images(output_tensor)
249-
return output_tensor
250-
251-
if output_type == "mp4" and return_type == "mp4":
252-
return response.content
248+
output = [
249+
Image.fromarray(image)
250+
for image in (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255).round().astype("uint8")
251+
]
252+
elif return_type == "pt":
253+
output = output_tensor
254+
elif output_type == "mp4" and return_type == "mp4":
255+
output = response.content
256+
return output

0 commit comments

Comments
 (0)