Skip to content

Commit d80d66c

Browse files
committed
address comments
1 parent 2af1995 commit d80d66c

File tree

1 file changed: +109 −63 lines

src/diffusers/utils/remote_utils.py

Lines changed: 109 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,91 @@ def check_inputs(
6262
)
6363

6464

65+
def postprocess(
    response: requests.Response,
    processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
    output_type: Literal["mp4", "pil", "pt"] = "pil",
    return_type: Literal["mp4", "pil", "pt"] = "pil",
    partial_postprocess: bool = False,
):
    """
    Convert a remote-decode HTTP response into the requested output object.

    Args:
        response (`requests.Response`):
            Successful response from the remote decode endpoint. For tensor
            payloads the raw bytes are in `response.content` and the `shape`/
            `dtype` metadata is carried in `response.headers`.
        processor (`VaeImageProcessor` or `VideoProcessor`, *optional*):
            Postprocessor used to turn the decoded tensor into PIL images or
            video frames.
        output_type (`"mp4"`, `"pil"` or `"pt"`, defaults to `"pil"`):
            Payload type the endpoint produced.
        return_type (`"mp4"`, `"pil"` or `"pt"`, defaults to `"pil"`):
            Type the caller wants back.
        partial_postprocess (`bool`, defaults to `False`):
            Whether the endpoint already denormalized the tensor into image
            values (so only the tensor -> PIL conversion remains client-side).

    Returns:
        `Image.Image`, `List[Image.Image]`, `bytes` or `torch.Tensor` depending
        on `output_type`/`return_type`/`processor`.

    Raises:
        ValueError: If the `output_type`/`return_type`/`processor` combination
            is not supported.
    """
    output = None
    if output_type == "pt" or (output_type == "pil" and processor is not None):
        # Rebuild the tensor from raw bytes using the shape/dtype metadata
        # the endpoint sends back in the response headers.
        shape = json.loads(response.headers["shape"])
        torch_dtype = DTYPE_MAP[response.headers["dtype"]]
        output_tensor = torch.frombuffer(bytearray(response.content), dtype=torch_dtype).reshape(shape)
    if output_type == "pt":
        if partial_postprocess:
            if return_type == "pil":
                # Endpoint already denormalized; each batch entry is assumed to
                # be an HWC uint8 image -- TODO(review): confirm against server.
                output = [Image.fromarray(image.numpy()) for image in output_tensor]
                if len(output) == 1:
                    # Unwrap single-image batches for convenience.
                    output = output[0]
            elif return_type == "pt":
                output = output_tensor
        else:
            if processor is None or return_type == "pt":
                output = output_tensor
            elif isinstance(processor, VideoProcessor):
                output = cast(
                    List[Image.Image],
                    processor.postprocess_video(output_tensor, output_type="pil")[0],
                )
            else:
                output = cast(
                    Image.Image,
                    processor.postprocess(output_tensor, output_type="pil")[0],
                )
    elif output_type == "pil" and return_type == "pil" and processor is None:
        # Endpoint returned an encoded image file directly; decode it.
        output = Image.open(io.BytesIO(response.content)).convert("RGB")
    elif output_type == "pil" and processor is not None:
        if return_type == "pil":
            # Denormalize the NCHW float tensor (presumably in [0, 1] -- TODO
            # confirm) to HWC uint8 and wrap each image as PIL.
            output = [
                Image.fromarray(image)
                for image in (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255).round().astype("uint8")
            ]
        elif return_type == "pt":
            output = output_tensor
    elif output_type == "mp4" and return_type == "mp4":
        output = response.content
    if output is None:
        # Previously unsupported combinations fell through to an opaque
        # UnboundLocalError on `return output`; fail with a clear message.
        raise ValueError(
            f"Unsupported combination: output_type={output_type!r}, return_type={return_type!r}, "
            f"processor={'set' if processor is not None else 'None'}."
        )
    return output
114+
115+
116+
def prepare(
    tensor: "torch.Tensor",
    do_scaling: bool = True,
    scaling_factor: Optional[float] = None,
    shift_factor: Optional[float] = None,
    output_type: Literal["mp4", "pil", "pt"] = "pil",
    image_format: Literal["png", "jpg"] = "jpg",
    partial_postprocess: bool = False,
    height: Optional[int] = None,
    width: Optional[int] = None,
):
    """
    Build the request payload for a remote decode call.

    Args:
        tensor (`torch.Tensor`):
            Latent tensor to send; serialized to raw bytes as the request body.
        do_scaling (`bool`, defaults to `True`):
            Whether latent scaling/shifting should be applied server-side.
        scaling_factor (`float`, *optional*):
            Explicit scaling factor forwarded to the endpoint.
        shift_factor (`float`, *optional*):
            Explicit shift factor forwarded to the endpoint.
        output_type (`"mp4"`, `"pil"` or `"pt"`, defaults to `"pil"`):
            Payload type requested from the endpoint.
        image_format (`"png"` or `"jpg"`, defaults to `"jpg"`):
            Encoding used when the endpoint returns image files.
        partial_postprocess (`bool`, defaults to `False`):
            Ask the endpoint to denormalize the decoded tensor server-side.
        height (`int`, *optional*):
            Required for "packed" latents.
        width (`int`, *optional*):
            Required for "packed" latents.

    Returns:
        `dict` with keys `"data"` (raw tensor bytes), `"params"` (query
        parameters) and `"headers"`, suitable for
        `requests.post(endpoint, **kwargs)`.
    """
    headers = {}
    parameters = {
        "image_format": image_format,
        "output_type": output_type,
        "partial_postprocess": partial_postprocess,
        "shape": list(tensor.shape),
        # e.g. torch.float16 -> "float16"
        "dtype": str(tensor.dtype).split(".")[-1],
    }
    if do_scaling:
        if scaling_factor is not None:
            parameters["scaling_factor"] = scaling_factor
        if shift_factor is not None:
            parameters["shift_factor"] = shift_factor
        if scaling_factor is None:
            # No explicit factor: tell the endpoint to apply its own scaling.
            # (An `elif do_scaling and scaling_factor is None and shift_factor
            # is None` branch in the original was unreachable -- strictly
            # narrower than this condition -- and has been removed.)
            parameters["do_scaling"] = do_scaling
    if height is not None and width is not None:
        parameters["height"] = height
        parameters["width"] = width
    tensor_data = safetensors.torch._tobytes(tensor, "tensor")
    return {"data": tensor_data, "params": parameters, "headers": headers}
148+
149+
65150
def remote_decode(
66151
endpoint: str,
67152
tensor: "torch.Tensor",
@@ -79,6 +164,8 @@ def remote_decode(
79164
width: Optional[int] = None,
80165
) -> Union[Image.Image, List[Image.Image], bytes, "torch.Tensor"]:
81166
"""
167+
Hugging Face Hybrid Inference
168+
82169
Args:
83170
endpoint (`str`):
84171
Endpoint for Remote Decode.
@@ -140,6 +227,9 @@ def remote_decode(
140227
141228
width (`int`, **optional**):
142229
Required for `"packed"` latents.
230+
231+
Returns:
232+
output (`Image.Image` or `List[Image.Image]` or `bytes` or `torch.Tensor`).
143233
"""
144234
if input_tensor_type == "base64":
145235
deprecate(
@@ -173,69 +263,25 @@ def remote_decode(
173263
height,
174264
width,
175265
)
176-
headers = {}
177-
parameters = {
178-
"image_format": image_format,
179-
"output_type": output_type,
180-
"partial_postprocess": partial_postprocess,
181-
"shape": list(tensor.shape),
182-
"dtype": str(tensor.dtype).split(".")[-1],
183-
}
184-
if do_scaling and scaling_factor is not None:
185-
parameters["scaling_factor"] = scaling_factor
186-
if do_scaling and shift_factor is not None:
187-
parameters["shift_factor"] = shift_factor
188-
if do_scaling and scaling_factor is None:
189-
parameters["do_scaling"] = do_scaling
190-
elif do_scaling and scaling_factor is None and shift_factor is None:
191-
parameters["do_scaling"] = do_scaling
192-
if height is not None and width is not None:
193-
parameters["height"] = height
194-
parameters["width"] = width
195-
tensor_data = safetensors.torch._tobytes(tensor, "tensor")
196-
kwargs = {"data": tensor_data}
197-
response = requests.post(endpoint, params=parameters, **kwargs, headers=headers)
266+
kwargs = prepare(
267+
tensor=tensor,
268+
do_scaling=do_scaling,
269+
scaling_factor=scaling_factor,
270+
shift_factor=shift_factor,
271+
output_type=output_type,
272+
image_format=image_format,
273+
partial_postprocess=partial_postprocess,
274+
height=height,
275+
width=width,
276+
)
277+
response = requests.post(endpoint, **kwargs)
198278
if not response.ok:
199279
raise RuntimeError(response.json())
200-
if output_type == "pt" or (output_type == "pil" and processor is not None):
201-
output_tensor = response.content
202-
parameters = response.headers
203-
shape = json.loads(parameters["shape"])
204-
dtype = parameters["dtype"]
205-
torch_dtype = DTYPE_MAP[dtype]
206-
output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
207-
if output_type == "pt":
208-
if partial_postprocess:
209-
if return_type == "pil":
210-
output = [Image.fromarray(image.numpy()) for image in output_tensor]
211-
if len(output) == 1:
212-
output = output[0]
213-
elif return_type == "pt":
214-
output = output_tensor
215-
else:
216-
if processor is None or return_type == "pt":
217-
output = output_tensor
218-
else:
219-
if isinstance(processor, VideoProcessor):
220-
output = cast(
221-
List[Image.Image],
222-
processor.postprocess_video(output_tensor, output_type="pil")[0],
223-
)
224-
else:
225-
output = cast(
226-
Image.Image,
227-
processor.postprocess(output_tensor, output_type="pil")[0],
228-
)
229-
elif output_type == "pil" and return_type == "pil" and processor is None:
230-
output = Image.open(io.BytesIO(response.content)).convert("RGB")
231-
elif output_type == "pil" and processor is not None:
232-
if return_type == "pil":
233-
output = [
234-
Image.fromarray(image)
235-
for image in (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255).round().astype("uint8")
236-
]
237-
elif return_type == "pt":
238-
output = output_tensor
239-
elif output_type == "mp4" and return_type == "mp4":
240-
output = response.content
280+
output = postprocess(
281+
response=response,
282+
processor=processor,
283+
output_type=output_type,
284+
return_type=return_type,
285+
partial_postprocess=partial_postprocess,
286+
)
241287
return output

0 commit comments

Comments
 (0)