Skip to content

Commit 2af1995

Browse files
committed
deprecate base64, headers not needed
1 parent 54280dd commit 2af1995

File tree

1 file changed

+30
-42
lines changed

1 file changed

+30
-42
lines changed

src/diffusers/utils/remote_utils.py

Lines changed: 30 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import base64
21
import io
32
import json
43
from typing import List, Literal, Optional, Union, cast
@@ -40,8 +39,8 @@ def check_inputs(
4039
return_type: Literal["mp4", "pil", "pt"] = "pil",
4140
image_format: Literal["png", "jpg"] = "jpg",
4241
partial_postprocess: bool = False,
43-
input_tensor_type: Literal["base64", "binary"] = "base64",
44-
output_tensor_type: Literal["base64", "binary"] = "base64",
42+
input_tensor_type: Literal["binary"] = "binary",
43+
output_tensor_type: Literal["binary"] = "binary",
4544
height: Optional[int] = None,
4645
width: Optional[int] = None,
4746
):
@@ -74,8 +73,8 @@ def remote_decode(
7473
return_type: Literal["mp4", "pil", "pt"] = "pil",
7574
image_format: Literal["png", "jpg"] = "jpg",
7675
partial_postprocess: bool = False,
77-
input_tensor_type: Literal["base64", "binary"] = "base64",
78-
output_tensor_type: Literal["base64", "binary"] = "base64",
76+
input_tensor_type: Literal["binary"] = "binary",
77+
output_tensor_type: Literal["binary"] = "binary",
7978
height: Optional[int] = None,
8079
width: Optional[int] = None,
8180
) -> Union[Image.Image, List[Image.Image], bytes, "torch.Tensor"]:
@@ -130,18 +129,34 @@ def remote_decode(
130129
Used with `output_type="pt"`. `partial_postprocess=False` tensor is `float16` or `bfloat16`, without
131130
denormalization. `partial_postprocess=True` tensor is `uint8`, denormalized.
132131
133-
input_tensor_type (`"base64"` or `"binary"`, default `"base64"`):
134-
With `"base64"` `tensor` is sent to endpoint base64 encoded. `"binary"` reduces overhead and transfer.
132+
input_tensor_type (`"binary"`, default `"binary"`):
133+
Tensor transfer type.
135134
136-
output_tensor_type (`"base64"` or `"binary"`, default `"base64"`):
137-
With `"base64"` `tensor` returned by endpoint is base64 encoded. `"binary"` reduces overhead and transfer.
135+
output_tensor_type (`"binary"`, default `"binary"`):
136+
Tensor transfer type.
138137
139138
height (`int`, **optional**):
140139
Required for `"packed"` latents.
141140
142141
width (`int`, **optional**):
143142
Required for `"packed"` latents.
144143
"""
144+
if input_tensor_type == "base64":
145+
deprecate(
146+
"input_tensor_type='base64'",
147+
"1.0.0",
148+
"input_tensor_type='base64' is deprecated. Using `binary`.",
149+
standard_warn=False,
150+
)
151+
input_tensor_type = "binary"
152+
if output_tensor_type == "base64":
153+
deprecate(
154+
"output_tensor_type='base64'",
155+
"1.0.0",
156+
"output_tensor_type='base64' is deprecated. Using `binary`.",
157+
standard_warn=False,
158+
)
159+
output_tensor_type = "binary"
145160
check_inputs(
146161
endpoint,
147162
tensor,
@@ -160,6 +175,7 @@ def remote_decode(
160175
)
161176
headers = {}
162177
parameters = {
178+
"image_format": image_format,
163179
"output_type": output_type,
164180
"partial_postprocess": partial_postprocess,
165181
"shape": list(tensor.shape),
@@ -177,43 +193,15 @@ def remote_decode(
177193
parameters["height"] = height
178194
parameters["width"] = width
179195
tensor_data = safetensors.torch._tobytes(tensor, "tensor")
180-
if input_tensor_type == "base64":
181-
headers["Content-Type"] = "tensor/base64"
182-
elif input_tensor_type == "binary":
183-
headers["Content-Type"] = "tensor/binary"
184-
if output_type == "pil" and image_format == "jpg" and processor is None:
185-
headers["Accept"] = "image/jpeg"
186-
elif output_type == "pil" and image_format == "png" and processor is None:
187-
headers["Accept"] = "image/png"
188-
elif (output_tensor_type == "base64" and output_type == "pt") or (
189-
output_tensor_type == "base64" and output_type == "pil" and processor is not None
190-
):
191-
headers["Accept"] = "tensor/base64"
192-
elif (output_tensor_type == "binary" and output_type == "pt") or (
193-
output_tensor_type == "binary" and output_type == "pil" and processor is not None
194-
):
195-
headers["Accept"] = "tensor/binary"
196-
elif output_type == "mp4":
197-
headers["Accept"] = "text/plain"
198-
if input_tensor_type == "base64":
199-
kwargs = {"json": {"inputs": base64.b64encode(tensor_data).decode("utf-8")}}
200-
elif input_tensor_type == "binary":
201-
kwargs = {"data": tensor_data}
196+
kwargs = {"data": tensor_data}
202197
response = requests.post(endpoint, params=parameters, **kwargs, headers=headers)
203198
if not response.ok:
204199
raise RuntimeError(response.json())
205200
if output_type == "pt" or (output_type == "pil" and processor is not None):
206-
if output_tensor_type == "base64":
207-
content = response.json()
208-
output_tensor = base64.b64decode(content["inputs"])
209-
parameters = content["parameters"]
210-
shape = parameters["shape"]
211-
dtype = parameters["dtype"]
212-
elif output_tensor_type == "binary":
213-
output_tensor = response.content
214-
parameters = response.headers
215-
shape = json.loads(parameters["shape"])
216-
dtype = parameters["dtype"]
201+
output_tensor = response.content
202+
parameters = response.headers
203+
shape = json.loads(parameters["shape"])
204+
dtype = parameters["dtype"]
217205
torch_dtype = DTYPE_MAP[dtype]
218206
output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
219207
if output_type == "pt":

0 commit comments

Comments
 (0)