44from typing import List , Literal , Optional , Union , cast
55
66import requests
7+ from PIL import Image
78
89from .import_utils import is_safetensors_available , is_torch_available
910
1617
1718 if is_safetensors_available ():
1819 import safetensors
19-
2020 DTYPE_MAP = {
2121 "float16" : torch .float16 ,
2222 "float32" : torch .float32 ,
2525 }
2626
2727
28- from PIL import Image
29-
30-
3128def check_inputs (
3229 endpoint : str ,
3330 tensor : "torch.Tensor" ,
@@ -53,6 +50,74 @@ def check_inputs(
5350 raise ValueError ("`processor` is required." )
5451
5552
53+ def _prepare_headers (
54+ input_tensor_type : Literal ["base64" , "binary" ],
55+ output_type : Literal ["mp4" , "pil" , "pt" ],
56+ image_format : Literal ["png" , "jpg" ],
57+ processor : Optional [Union ["VaeImageProcessor" , "VideoProcessor" ]],
58+ output_tensor_type : Literal ["base64" , "binary" ],
59+ ) -> dict :
60+ headers = {}
61+ headers ["Content-Type" ] = "tensor/base64" if input_tensor_type == "base64" else "tensor/binary"
62+
63+ if output_type == "pil" :
64+ if processor is None :
65+ headers ["Accept" ] = "image/jpeg" if image_format == "jpg" else "image/png"
66+ else :
67+ headers ["Accept" ] = "tensor/base64" if output_tensor_type == "base64" else "tensor/binary"
68+ elif output_type == "pt" :
69+ headers ["Accept" ] = "tensor/base64" if output_tensor_type == "base64" else "tensor/binary"
70+ elif output_type == "mp4" :
71+ headers ["Accept" ] = "text/plain"
72+ return headers
73+
74+
75+ def _prepare_parameters (
76+ tensor : "torch.Tensor" ,
77+ do_scaling : bool ,
78+ output_type : Literal ["mp4" , "pil" , "pt" ],
79+ partial_postprocess : bool ,
80+ height : Optional [int ],
81+ width : Optional [int ],
82+ ) -> dict :
83+ params = {
84+ "do_scaling" : do_scaling ,
85+ "output_type" : output_type ,
86+ "partial_postprocess" : partial_postprocess ,
87+ "shape" : list (tensor .shape ),
88+ "dtype" : str (tensor .dtype ).split ("." )[- 1 ],
89+ }
90+ if height is not None and width is not None :
91+ params ["height" ] = height
92+ params ["width" ] = width
93+ return params
94+
95+
def _encode_tensor_data(tensor: "torch.Tensor", input_tensor_type: Literal["base64", "binary"]) -> dict:
    """
    Serialize `tensor` into the keyword arguments used for the request body.

    Args:
        tensor: Tensor to serialize with `safetensors.torch._tobytes`.
        input_tensor_type: `"base64"` produces a JSON body; any other value produces a raw byte body.

    Returns:
        `{"json": {"inputs": <base64 text>}}` for `"base64"`, otherwise `{"data": <bytes>}`.
    """
    raw = safetensors.torch._tobytes(tensor, "tensor")

    if input_tensor_type != "base64":
        return {"data": raw}

    encoded = base64.b64encode(raw).decode("utf-8")
    return {"json": {"inputs": encoded}}
101+
102+
def _decode_tensor_response(response: requests.Response, output_tensor_type: Literal["base64", "binary"]):
    """
    Extract the raw tensor bytes and their metadata from a response.

    Args:
        response: Response returned by the endpoint.
        output_tensor_type: `"base64"` reads a JSON body; any other value reads the raw body and the headers.

    Returns:
        A `(tensor_bytes, params)` tuple. For `"base64"`, `params` is the `"parameters"` entry of the JSON body.
        Otherwise it is a copy of the response headers whose `"shape"` entry has been parsed from JSON.
    """
    if output_tensor_type != "base64":
        raw = response.content
        # Work on a copy so the response's own headers are left untouched.
        header_params = response.headers.copy()
        header_params["shape"] = json.loads(header_params["shape"])
        return raw, header_params

    body = response.json()
    raw = base64.b64decode(body["inputs"])
    return raw, body["parameters"]
113+
114+
def _tensor_to_pil_images(tensor: "torch.Tensor") -> Union[Image.Image, List[Image.Image]]:
    """
    Convert a batched tensor into PIL image(s).

    Each batch item is permuted with `(1, 2, 0)`, multiplied by 255, rounded and cast to `uint8` before being
    handed to `Image.fromarray`.

    Args:
        tensor: Assumed to be laid out as `[batch, channels, height, width]`.
            NOTE(review): presumably values are in `[0, 1]`; nothing here clips them, so out-of-range values
            (or tensors that are already `uint8` image data) would wrap in the `uint8` cast -- confirm callers.

    Returns:
        A single image when the batch holds exactly one item, otherwise a list of images.
    """
    images = []
    for img in tensor:
        # Cast to float32 before leaving torch: `Tensor.numpy()` does not support bfloat16, and scaling by 255
        # in half precision loses accuracy.
        array = img.permute(1, 2, 0).float().cpu().numpy()
        images.append(Image.fromarray((array * 255).round().astype("uint8")))
    return images[0] if len(images) == 1 else images
119+
120+
56121def remote_decode (
57122 endpoint : str ,
58123 tensor : "torch.Tensor" ,
@@ -125,6 +190,7 @@ def remote_decode(
125190 width (`int`, **optional**):
126191 Required for `"packed"` latents.
127192 """
193+
128194 check_inputs (
129195 endpoint ,
130196 tensor ,
@@ -139,89 +205,48 @@ def remote_decode(
139205 height ,
140206 width ,
141207 )
142- headers = {}
143- parameters = {
144- "do_scaling" : do_scaling ,
145- "output_type" : output_type ,
146- "partial_postprocess" : partial_postprocess ,
147- "shape" : list (tensor .shape ),
148- "dtype" : str (tensor .dtype ).split ("." )[- 1 ],
149- }
150- if height is not None and width is not None :
151- parameters ["height" ] = height
152- parameters ["width" ] = width
153- tensor_data = safetensors .torch ._tobytes (tensor , "tensor" )
154- if input_tensor_type == "base64" :
155- headers ["Content-Type" ] = "tensor/base64"
156- elif input_tensor_type == "binary" :
157- headers ["Content-Type" ] = "tensor/binary"
158- if output_type == "pil" and image_format == "jpg" and processor is None :
159- headers ["Accept" ] = "image/jpeg"
160- elif output_type == "pil" and image_format == "png" and processor is None :
161- headers ["Accept" ] = "image/png"
162- elif (output_tensor_type == "base64" and output_type == "pt" ) or (
163- output_tensor_type == "base64" and output_type == "pil" and processor is not None
164- ):
165- headers ["Accept" ] = "tensor/base64"
166- elif (output_tensor_type == "binary" and output_type == "pt" ) or (
167- output_tensor_type == "binary" and output_type == "pil" and processor is not None
168- ):
169- headers ["Accept" ] = "tensor/binary"
170- elif output_type == "mp4" :
171- headers ["Accept" ] = "text/plain"
172- if input_tensor_type == "base64" :
173- kwargs = {"json" : {"inputs" : base64 .b64encode (tensor_data ).decode ("utf-8" )}}
174- elif input_tensor_type == "binary" :
175- kwargs = {"data" : tensor_data }
176- response = requests .post (endpoint , params = parameters , ** kwargs , headers = headers )
208+
209+ # Prepare request details.
210+ headers = _prepare_headers (input_tensor_type , output_type , image_format , processor , output_tensor_type )
211+ params = _prepare_parameters (tensor , do_scaling , output_type , partial_postprocess , height , width )
212+ payload = _encode_tensor_data (tensor , input_tensor_type )
213+
214+ response = requests .post (endpoint , params = params , headers = headers , ** payload )
177215 if not response .ok :
178216 raise RuntimeError (response .json ())
179- if output_type == "pt" or (output_type == "pil" and processor is not None ):
180- if output_tensor_type == "base64" :
181- content = response .json ()
182- output_tensor = base64 .b64decode (content ["inputs" ])
183- parameters = content ["parameters" ]
184- shape = parameters ["shape" ]
185- dtype = parameters ["dtype" ]
186- elif output_tensor_type == "binary" :
187- output_tensor = response .content
188- parameters = response .headers
189- shape = json .loads (parameters ["shape" ])
190- dtype = parameters ["dtype" ]
217+
218+ # Process responses that return a tensor.
219+ if output_type in ("pt" ,) or (output_type == "pil" and processor is not None ):
220+ tensor_bytes , tensor_params = _decode_tensor_response (response , output_tensor_type )
221+ shape = tensor_params ["shape" ]
222+ dtype = tensor_params ["dtype" ]
191223 torch_dtype = DTYPE_MAP [dtype ]
192- output_tensor = torch .frombuffer (bytearray (output_tensor ), dtype = torch_dtype ).reshape (shape )
193- if output_type == "pt" :
194- if partial_postprocess :
195- if return_type == "pil" :
196- output = [Image .fromarray (image .numpy ()) for image in output_tensor ]
197- if len (output ) == 1 :
198- output = output [0 ]
199- elif return_type == "pt" :
200- output = output_tensor
201- else :
202- if processor is None or return_type == "pt" :
203- output = output_tensor
224+ output_tensor = torch .frombuffer (bytearray (tensor_bytes ), dtype = torch_dtype ).reshape (shape )
225+
226+ if output_type == "pt" :
227+ if partial_postprocess :
228+ if return_type == "pil" :
229+ return _tensor_to_pil_images (output_tensor )
230+ return output_tensor
204231 else :
232+ if processor is None or return_type == "pt" :
233+ return output_tensor
205234 if isinstance (processor , VideoProcessor ):
206- output = cast (
207- List [ Image .Image ],
208- processor . postprocess_video ( output_tensor , output_type = "pil" )[ 0 ],
209- )
210- else :
211- output = cast (
212- Image . Image ,
213- processor . postprocess ( output_tensor , output_type = "pil" )[ 0 ],
214- )
215- elif output_type == "pil" and return_type == "pil" and processor is None :
216- output = Image . open ( io . BytesIO ( response . content )). convert ( "RGB" )
217- elif output_type == "pil" and processor is not None :
235+ return cast (List [ Image . Image ], processor . postprocess_video ( output_tensor , output_type = "pil" )[ 0 ])
236+ return cast ( Image .Image , processor . postprocess ( output_tensor , output_type = "pil" )[ 0 ])
237+
238+ if output_type == "pil" and processor is None and return_type == "pil" :
239+ return Image . open ( io . BytesIO ( response . content )). convert ( "RGB" )
240+
241+ if output_type == "pil" and processor is not None :
242+ tensor_bytes , tensor_params = _decode_tensor_response ( response , output_tensor_type )
243+ shape = tensor_params [ "shape" ]
244+ dtype = tensor_params [ "dtype" ]
245+ torch_dtype = DTYPE_MAP [ dtype ]
246+ output_tensor = torch . frombuffer ( bytearray ( tensor_bytes ), dtype = torch_dtype ). reshape ( shape )
218247 if return_type == "pil" :
219- output = [
220- Image .fromarray (image )
221- for image in (output_tensor .permute (0 , 2 , 3 , 1 ).float ().numpy () * 255 ).round ().astype ("uint8" )
222- ]
223- elif return_type == "pt" :
224- output = output_tensor
225- elif output_type == "mp4" and return_type == "mp4" :
226- output = response .content
227- return output
248+ return _tensor_to_pil_images (output_tensor )
249+ return output_tensor
250+
251+ if output_type == "mp4" and return_type == "mp4" :
252+ return response .content
0 commit comments