Skip to content

Commit 1978a8a

Browse files
committed
apply
1 parent 2c572f7 commit 1978a8a

File tree

1 file changed

+109
-84
lines changed

1 file changed

+109
-84
lines changed

src/diffusers/utils/remote_utils.py

Lines changed: 109 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import List, Literal, Optional, Union, cast
55

66
import requests
7+
from PIL import Image
78

89
from .import_utils import is_safetensors_available, is_torch_available
910

@@ -16,7 +17,6 @@
1617

1718
if is_safetensors_available():
1819
import safetensors
19-
2020
DTYPE_MAP = {
2121
"float16": torch.float16,
2222
"float32": torch.float32,
@@ -25,9 +25,6 @@
2525
}
2626

2727

28-
from PIL import Image
29-
30-
3128
def check_inputs(
3229
endpoint: str,
3330
tensor: "torch.Tensor",
@@ -53,6 +50,74 @@ def check_inputs(
5350
raise ValueError("`processor` is required.")
5451

5552

53+
def _prepare_headers(
54+
input_tensor_type: Literal["base64", "binary"],
55+
output_type: Literal["mp4", "pil", "pt"],
56+
image_format: Literal["png", "jpg"],
57+
processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]],
58+
output_tensor_type: Literal["base64", "binary"],
59+
) -> dict:
60+
headers = {}
61+
headers["Content-Type"] = "tensor/base64" if input_tensor_type == "base64" else "tensor/binary"
62+
63+
if output_type == "pil":
64+
if processor is None:
65+
headers["Accept"] = "image/jpeg" if image_format == "jpg" else "image/png"
66+
else:
67+
headers["Accept"] = "tensor/base64" if output_tensor_type == "base64" else "tensor/binary"
68+
elif output_type == "pt":
69+
headers["Accept"] = "tensor/base64" if output_tensor_type == "base64" else "tensor/binary"
70+
elif output_type == "mp4":
71+
headers["Accept"] = "text/plain"
72+
return headers
73+
74+
75+
def _prepare_parameters(
76+
tensor: "torch.Tensor",
77+
do_scaling: bool,
78+
output_type: Literal["mp4", "pil", "pt"],
79+
partial_postprocess: bool,
80+
height: Optional[int],
81+
width: Optional[int],
82+
) -> dict:
83+
params = {
84+
"do_scaling": do_scaling,
85+
"output_type": output_type,
86+
"partial_postprocess": partial_postprocess,
87+
"shape": list(tensor.shape),
88+
"dtype": str(tensor.dtype).split(".")[-1],
89+
}
90+
if height is not None and width is not None:
91+
params["height"] = height
92+
params["width"] = width
93+
return params
94+
95+
96+
def _encode_tensor_data(tensor: "torch.Tensor", input_tensor_type: Literal["base64", "binary"]) -> dict:
    """
    Turn a tensor into the body keyword arguments for `requests.post`.

    Args:
        tensor (`torch.Tensor`):
            Tensor to send. It is converted to bytes with safetensors' private `_tobytes` helper.
        input_tensor_type (`"base64"` or `"binary"`):
            `"base64"` produces a JSON body; any other value produces a raw bytes body.

    Returns:
        `dict`: Either `{"json": {"inputs": <base64 text>}}` or `{"data": <bytes>}`.
    """
    raw = safetensors.torch._tobytes(tensor, "tensor")

    if input_tensor_type != "base64":
        return {"data": raw}

    encoded = base64.b64encode(raw).decode("utf-8")
    return {"json": {"inputs": encoded}}
101+
102+
103+
def _decode_tensor_response(response: requests.Response, output_tensor_type: Literal["base64", "binary"]):
    """
    Extract the raw tensor bytes and their metadata from an endpoint response.

    Args:
        response (`requests.Response`):
            Response returned by the endpoint.
        output_tensor_type (`"base64"` or `"binary"`):
            `"base64"` reads a JSON body with `"inputs"` and `"parameters"` keys. Any other value reads the raw
            body and takes the metadata from the response headers.

    Returns:
        `tuple`: `(tensor_bytes, params)`. In the non-base64 case `params` is a copy of the response headers
        whose `"shape"` entry has been parsed from JSON.
    """
    if output_tensor_type != "base64":
        raw = response.content
        # Copy so that parsing "shape" does not modify the response's own headers.
        metadata = response.headers.copy()
        metadata["shape"] = json.loads(metadata["shape"])
        return raw, metadata

    body = response.json()
    return base64.b64decode(body["inputs"]), body["parameters"]
113+
114+
115+
def _tensor_to_pil_images(tensor: "torch.Tensor") -> Union[Image.Image, List[Image.Image]]:
    """
    Convert a batch of image tensors into PIL images.

    Args:
        tensor (`torch.Tensor`):
            Assumed to be laid out as `[batch, channels, height, width]`. Values are multiplied by 255 and cast
            to `uint8`, so they are presumably expected in `[0, 1]` -- TODO confirm against the endpoint.

    Returns:
        `PIL.Image.Image` or `List[PIL.Image.Image]`: A single image when the batch holds exactly one item,
        otherwise a list with one image per batch item.
    """
    # Upcast to float32 before leaving torch: `Tensor.numpy()` raises for bfloat16, and scaling float16 by 255
    # loses precision. The inline conversion this helper replaces applied the same `.float()` step.
    images = [
        Image.fromarray((img.permute(1, 2, 0).float().cpu().numpy() * 255).round().astype("uint8"))
        for img in tensor
    ]
    return images[0] if len(images) == 1 else images
119+
120+
56121
def remote_decode(
57122
endpoint: str,
58123
tensor: "torch.Tensor",
@@ -125,6 +190,7 @@ def remote_decode(
125190
width (`int`, **optional**):
126191
Required for `"packed"` latents.
127192
"""
193+
128194
check_inputs(
129195
endpoint,
130196
tensor,
@@ -139,89 +205,48 @@ def remote_decode(
139205
height,
140206
width,
141207
)
142-
headers = {}
143-
parameters = {
144-
"do_scaling": do_scaling,
145-
"output_type": output_type,
146-
"partial_postprocess": partial_postprocess,
147-
"shape": list(tensor.shape),
148-
"dtype": str(tensor.dtype).split(".")[-1],
149-
}
150-
if height is not None and width is not None:
151-
parameters["height"] = height
152-
parameters["width"] = width
153-
tensor_data = safetensors.torch._tobytes(tensor, "tensor")
154-
if input_tensor_type == "base64":
155-
headers["Content-Type"] = "tensor/base64"
156-
elif input_tensor_type == "binary":
157-
headers["Content-Type"] = "tensor/binary"
158-
if output_type == "pil" and image_format == "jpg" and processor is None:
159-
headers["Accept"] = "image/jpeg"
160-
elif output_type == "pil" and image_format == "png" and processor is None:
161-
headers["Accept"] = "image/png"
162-
elif (output_tensor_type == "base64" and output_type == "pt") or (
163-
output_tensor_type == "base64" and output_type == "pil" and processor is not None
164-
):
165-
headers["Accept"] = "tensor/base64"
166-
elif (output_tensor_type == "binary" and output_type == "pt") or (
167-
output_tensor_type == "binary" and output_type == "pil" and processor is not None
168-
):
169-
headers["Accept"] = "tensor/binary"
170-
elif output_type == "mp4":
171-
headers["Accept"] = "text/plain"
172-
if input_tensor_type == "base64":
173-
kwargs = {"json": {"inputs": base64.b64encode(tensor_data).decode("utf-8")}}
174-
elif input_tensor_type == "binary":
175-
kwargs = {"data": tensor_data}
176-
response = requests.post(endpoint, params=parameters, **kwargs, headers=headers)
208+
209+
# Prepare request details.
210+
headers = _prepare_headers(input_tensor_type, output_type, image_format, processor, output_tensor_type)
211+
params = _prepare_parameters(tensor, do_scaling, output_type, partial_postprocess, height, width)
212+
payload = _encode_tensor_data(tensor, input_tensor_type)
213+
214+
response = requests.post(endpoint, params=params, headers=headers, **payload)
177215
if not response.ok:
178216
raise RuntimeError(response.json())
179-
if output_type == "pt" or (output_type == "pil" and processor is not None):
180-
if output_tensor_type == "base64":
181-
content = response.json()
182-
output_tensor = base64.b64decode(content["inputs"])
183-
parameters = content["parameters"]
184-
shape = parameters["shape"]
185-
dtype = parameters["dtype"]
186-
elif output_tensor_type == "binary":
187-
output_tensor = response.content
188-
parameters = response.headers
189-
shape = json.loads(parameters["shape"])
190-
dtype = parameters["dtype"]
217+
218+
# Process responses that return a tensor.
219+
if output_type in ("pt",) or (output_type == "pil" and processor is not None):
220+
tensor_bytes, tensor_params = _decode_tensor_response(response, output_tensor_type)
221+
shape = tensor_params["shape"]
222+
dtype = tensor_params["dtype"]
191223
torch_dtype = DTYPE_MAP[dtype]
192-
output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
193-
if output_type == "pt":
194-
if partial_postprocess:
195-
if return_type == "pil":
196-
output = [Image.fromarray(image.numpy()) for image in output_tensor]
197-
if len(output) == 1:
198-
output = output[0]
199-
elif return_type == "pt":
200-
output = output_tensor
201-
else:
202-
if processor is None or return_type == "pt":
203-
output = output_tensor
224+
output_tensor = torch.frombuffer(bytearray(tensor_bytes), dtype=torch_dtype).reshape(shape)
225+
226+
if output_type == "pt":
227+
if partial_postprocess:
228+
if return_type == "pil":
229+
return _tensor_to_pil_images(output_tensor)
230+
return output_tensor
204231
else:
232+
if processor is None or return_type == "pt":
233+
return output_tensor
205234
if isinstance(processor, VideoProcessor):
206-
output = cast(
207-
List[Image.Image],
208-
processor.postprocess_video(output_tensor, output_type="pil")[0],
209-
)
210-
else:
211-
output = cast(
212-
Image.Image,
213-
processor.postprocess(output_tensor, output_type="pil")[0],
214-
)
215-
elif output_type == "pil" and return_type == "pil" and processor is None:
216-
output = Image.open(io.BytesIO(response.content)).convert("RGB")
217-
elif output_type == "pil" and processor is not None:
235+
return cast(List[Image.Image], processor.postprocess_video(output_tensor, output_type="pil")[0])
236+
return cast(Image.Image, processor.postprocess(output_tensor, output_type="pil")[0])
237+
238+
if output_type == "pil" and processor is None and return_type == "pil":
239+
return Image.open(io.BytesIO(response.content)).convert("RGB")
240+
241+
if output_type == "pil" and processor is not None:
242+
tensor_bytes, tensor_params = _decode_tensor_response(response, output_tensor_type)
243+
shape = tensor_params["shape"]
244+
dtype = tensor_params["dtype"]
245+
torch_dtype = DTYPE_MAP[dtype]
246+
output_tensor = torch.frombuffer(bytearray(tensor_bytes), dtype=torch_dtype).reshape(shape)
218247
if return_type == "pil":
219-
output = [
220-
Image.fromarray(image)
221-
for image in (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255).round().astype("uint8")
222-
]
223-
elif return_type == "pt":
224-
output = output_tensor
225-
elif output_type == "mp4" and return_type == "mp4":
226-
output = response.content
227-
return output
248+
return _tensor_to_pil_images(output_tensor)
249+
return output_tensor
250+
251+
if output_type == "mp4" and return_type == "mp4":
252+
return response.content

0 commit comments

Comments
 (0)