1- import base64
21import io
32import json
43from typing import List , Literal , Optional , Union , cast
@@ -40,8 +39,8 @@ def check_inputs(
4039 return_type : Literal ["mp4" , "pil" , "pt" ] = "pil" ,
4140 image_format : Literal ["png" , "jpg" ] = "jpg" ,
4241 partial_postprocess : bool = False ,
43- input_tensor_type : Literal ["base64" , " binary" ] = "base64 " ,
44- output_tensor_type : Literal ["base64" , " binary" ] = "base64 " ,
42+ input_tensor_type : Literal ["binary" ] = "binary " ,
43+ output_tensor_type : Literal ["binary" ] = "binary " ,
4544 height : Optional [int ] = None ,
4645 width : Optional [int ] = None ,
4746):
@@ -74,8 +73,8 @@ def remote_decode(
7473 return_type : Literal ["mp4" , "pil" , "pt" ] = "pil" ,
7574 image_format : Literal ["png" , "jpg" ] = "jpg" ,
7675 partial_postprocess : bool = False ,
77- input_tensor_type : Literal ["base64" , " binary" ] = "base64 " ,
78- output_tensor_type : Literal ["base64" , " binary" ] = "base64 " ,
76+ input_tensor_type : Literal ["binary" ] = "binary " ,
77+ output_tensor_type : Literal ["binary" ] = "binary " ,
7978 height : Optional [int ] = None ,
8079 width : Optional [int ] = None ,
8180) -> Union [Image .Image , List [Image .Image ], bytes , "torch.Tensor" ]:
@@ -130,18 +129,34 @@ def remote_decode(
130129 Used with `output_type="pt"`. `partial_postprocess=False` tensor is `float16` or `bfloat16`, without
131130 denormalization. `partial_postprocess=True` tensor is `uint8`, denormalized.
132131
133- input_tensor_type (`"base64"` or `" binary"`, default `"base64 "`):
134- With `"base64"` `tensor` is sent to endpoint base64 encoded. `"binary"` reduces overhead and transfer.
132+ input_tensor_type (`"binary"`, default `"binary "`):
133+ Tensor transfer type .
135134
136- output_tensor_type (`"base64"` or `" binary"`, default `"base64 "`):
137- With `"base64"` `tensor` returned by endpoint is base64 encoded. `"binary"` reduces overhead and transfer.
135+ output_tensor_type (`"binary"`, default `"binary "`):
136+ Tensor transfer type .
138137
139138 height (`int`, **optional**):
140139 Required for `"packed"` latents.
141140
142141 width (`int`, **optional**):
143142 Required for `"packed"` latents.
144143 """
144+ if input_tensor_type == "base64" :
145+ deprecate (
146+ "input_tensor_type='base64'" ,
147+ "1.0.0" ,
148+ "input_tensor_type='base64' is deprecated. Using `binary`." ,
149+ standard_warn = False ,
150+ )
151+ input_tensor_type = "binary"
152+ if output_tensor_type == "base64" :
153+ deprecate (
154+ "output_tensor_type='base64'" ,
155+ "1.0.0" ,
156+ "output_tensor_type='base64' is deprecated. Using `binary`." ,
157+ standard_warn = False ,
158+ )
159+ output_tensor_type = "binary"
145160 check_inputs (
146161 endpoint ,
147162 tensor ,
@@ -160,6 +175,7 @@ def remote_decode(
160175 )
161176 headers = {}
162177 parameters = {
178+ "image_format" : image_format ,
163179 "output_type" : output_type ,
164180 "partial_postprocess" : partial_postprocess ,
165181 "shape" : list (tensor .shape ),
@@ -177,43 +193,15 @@ def remote_decode(
177193 parameters ["height" ] = height
178194 parameters ["width" ] = width
179195 tensor_data = safetensors .torch ._tobytes (tensor , "tensor" )
180- if input_tensor_type == "base64" :
181- headers ["Content-Type" ] = "tensor/base64"
182- elif input_tensor_type == "binary" :
183- headers ["Content-Type" ] = "tensor/binary"
184- if output_type == "pil" and image_format == "jpg" and processor is None :
185- headers ["Accept" ] = "image/jpeg"
186- elif output_type == "pil" and image_format == "png" and processor is None :
187- headers ["Accept" ] = "image/png"
188- elif (output_tensor_type == "base64" and output_type == "pt" ) or (
189- output_tensor_type == "base64" and output_type == "pil" and processor is not None
190- ):
191- headers ["Accept" ] = "tensor/base64"
192- elif (output_tensor_type == "binary" and output_type == "pt" ) or (
193- output_tensor_type == "binary" and output_type == "pil" and processor is not None
194- ):
195- headers ["Accept" ] = "tensor/binary"
196- elif output_type == "mp4" :
197- headers ["Accept" ] = "text/plain"
198- if input_tensor_type == "base64" :
199- kwargs = {"json" : {"inputs" : base64 .b64encode (tensor_data ).decode ("utf-8" )}}
200- elif input_tensor_type == "binary" :
201- kwargs = {"data" : tensor_data }
196+ kwargs = {"data" : tensor_data }
202197 response = requests .post (endpoint , params = parameters , ** kwargs , headers = headers )
203198 if not response .ok :
204199 raise RuntimeError (response .json ())
205200 if output_type == "pt" or (output_type == "pil" and processor is not None ):
206- if output_tensor_type == "base64" :
207- content = response .json ()
208- output_tensor = base64 .b64decode (content ["inputs" ])
209- parameters = content ["parameters" ]
210- shape = parameters ["shape" ]
211- dtype = parameters ["dtype" ]
212- elif output_tensor_type == "binary" :
213- output_tensor = response .content
214- parameters = response .headers
215- shape = json .loads (parameters ["shape" ])
216- dtype = parameters ["dtype" ]
201+ output_tensor = response .content
202+ parameters = response .headers
203+ shape = json .loads (parameters ["shape" ])
204+ dtype = parameters ["dtype" ]
217205 torch_dtype = DTYPE_MAP [dtype ]
218206 output_tensor = torch .frombuffer (bytearray (output_tensor ), dtype = torch_dtype ).reshape (shape )
219207 if output_type == "pt" :
0 commit comments