@@ -62,6 +62,91 @@ def check_inputs(
        )


+def postprocess(
+    response: requests.Response,
+    processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
+    output_type: Literal["mp4", "pil", "pt"] = "pil",
+    return_type: Literal["mp4", "pil", "pt"] = "pil",
+    partial_postprocess: bool = False,
+):
+    if output_type == "pt" or (output_type == "pil" and processor is not None):
+        output_tensor = response.content
+        parameters = response.headers
+        shape = json.loads(parameters["shape"])
+        dtype = parameters["dtype"]
+        torch_dtype = DTYPE_MAP[dtype]
+        output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
+    if output_type == "pt":
+        if partial_postprocess:
+            if return_type == "pil":
+                output = [Image.fromarray(image.numpy()) for image in output_tensor]
+                if len(output) == 1:
+                    output = output[0]
+            elif return_type == "pt":
+                output = output_tensor
+        else:
+            if processor is None or return_type == "pt":
+                output = output_tensor
+            else:
+                if isinstance(processor, VideoProcessor):
+                    output = cast(
+                        List[Image.Image],
+                        processor.postprocess_video(output_tensor, output_type="pil")[0],
+                    )
+                else:
+                    output = cast(
+                        Image.Image,
+                        processor.postprocess(output_tensor, output_type="pil")[0],
+                    )
+    elif output_type == "pil" and return_type == "pil" and processor is None:
+        output = Image.open(io.BytesIO(response.content)).convert("RGB")
+    elif output_type == "pil" and processor is not None:
+        if return_type == "pil":
+            output = [
+                Image.fromarray(image)
+                for image in (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255).round().astype("uint8")
+            ]
+        elif return_type == "pt":
+            output = output_tensor
+    elif output_type == "mp4" and return_type == "mp4":
+        output = response.content
+    return output
+
+
+def prepare(
+    tensor: "torch.Tensor",
+    do_scaling: bool = True,
+    scaling_factor: Optional[float] = None,
+    shift_factor: Optional[float] = None,
+    output_type: Literal["mp4", "pil", "pt"] = "pil",
+    image_format: Literal["png", "jpg"] = "jpg",
+    partial_postprocess: bool = False,
+    height: Optional[int] = None,
+    width: Optional[int] = None,
+):
+    headers = {}
+    parameters = {
+        "image_format": image_format,
+        "output_type": output_type,
+        "partial_postprocess": partial_postprocess,
+        "shape": list(tensor.shape),
+        "dtype": str(tensor.dtype).split(".")[-1],
+    }
+    if do_scaling and scaling_factor is not None:
+        parameters["scaling_factor"] = scaling_factor
+    if do_scaling and shift_factor is not None:
+        parameters["shift_factor"] = shift_factor
+    if do_scaling and scaling_factor is None:
+        parameters["do_scaling"] = do_scaling
+    elif do_scaling and scaling_factor is None and shift_factor is None:
+        parameters["do_scaling"] = do_scaling
+    if height is not None and width is not None:
+        parameters["height"] = height
+        parameters["width"] = width
+    tensor_data = safetensors.torch._tobytes(tensor, "tensor")
+    return {"data": tensor_data, "params": parameters, "headers": headers}
+
+
def remote_decode(
    endpoint: str,
    tensor: "torch.Tensor",
@@ -79,6 +164,8 @@ def remote_decode(
    width: Optional[int] = None,
) -> Union[Image.Image, List[Image.Image], bytes, "torch.Tensor"]:
    """
+    Hugging Face Hybrid Inference
+
    Args:
        endpoint (`str`):
            Endpoint for Remote Decode.
@@ -140,6 +227,9 @@ def remote_decode(

        width (`int`, *optional*):
            Required for `"packed"` latents.
+
+    Returns:
+        output (`Image.Image` or `List[Image.Image]` or `bytes` or `torch.Tensor`).
    """
    if input_tensor_type == "base64":
        deprecate(
@@ -173,69 +263,25 @@ def remote_decode(
        height,
        width,
    )
-    headers = {}
-    parameters = {
-        "image_format": image_format,
-        "output_type": output_type,
-        "partial_postprocess": partial_postprocess,
-        "shape": list(tensor.shape),
-        "dtype": str(tensor.dtype).split(".")[-1],
-    }
-    if do_scaling and scaling_factor is not None:
-        parameters["scaling_factor"] = scaling_factor
-    if do_scaling and shift_factor is not None:
-        parameters["shift_factor"] = shift_factor
-    if do_scaling and scaling_factor is None:
-        parameters["do_scaling"] = do_scaling
-    elif do_scaling and scaling_factor is None and shift_factor is None:
-        parameters["do_scaling"] = do_scaling
-    if height is not None and width is not None:
-        parameters["height"] = height
-        parameters["width"] = width
-    tensor_data = safetensors.torch._tobytes(tensor, "tensor")
-    kwargs = {"data": tensor_data}
-    response = requests.post(endpoint, params=parameters, **kwargs, headers=headers)
+    kwargs = prepare(
+        tensor=tensor,
+        do_scaling=do_scaling,
+        scaling_factor=scaling_factor,
+        shift_factor=shift_factor,
+        output_type=output_type,
+        image_format=image_format,
+        partial_postprocess=partial_postprocess,
+        height=height,
+        width=width,
+    )
+    response = requests.post(endpoint, **kwargs)
    if not response.ok:
        raise RuntimeError(response.json())
200- if output_type == "pt" or (output_type == "pil" and processor is not None ):
201- output_tensor = response .content
202- parameters = response .headers
203- shape = json .loads (parameters ["shape" ])
204- dtype = parameters ["dtype" ]
205- torch_dtype = DTYPE_MAP [dtype ]
206- output_tensor = torch .frombuffer (bytearray (output_tensor ), dtype = torch_dtype ).reshape (shape )
207- if output_type == "pt" :
208- if partial_postprocess :
209- if return_type == "pil" :
210- output = [Image .fromarray (image .numpy ()) for image in output_tensor ]
211- if len (output ) == 1 :
212- output = output [0 ]
213- elif return_type == "pt" :
214- output = output_tensor
215- else :
216- if processor is None or return_type == "pt" :
217- output = output_tensor
218- else :
219- if isinstance (processor , VideoProcessor ):
220- output = cast (
221- List [Image .Image ],
222- processor .postprocess_video (output_tensor , output_type = "pil" )[0 ],
223- )
224- else :
225- output = cast (
226- Image .Image ,
227- processor .postprocess (output_tensor , output_type = "pil" )[0 ],
228- )
229- elif output_type == "pil" and return_type == "pil" and processor is None :
230- output = Image .open (io .BytesIO (response .content )).convert ("RGB" )
231- elif output_type == "pil" and processor is not None :
232- if return_type == "pil" :
233- output = [
234- Image .fromarray (image )
235- for image in (output_tensor .permute (0 , 2 , 3 , 1 ).float ().numpy () * 255 ).round ().astype ("uint8" )
236- ]
237- elif return_type == "pt" :
238- output = output_tensor
239- elif output_type == "mp4" and return_type == "mp4" :
240- output = response .content
+    output = postprocess(
+        response=response,
+        processor=processor,
+        output_type=output_type,
+        return_type=return_type,
+        partial_postprocess=partial_postprocess,
+    )
    return output
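
Not part of the commit: a minimal usage sketch of `remote_decode` after this refactor, assuming the module path `diffusers.utils.remote_utils`. The endpoint URL and latent shape are placeholders, and `scaling_factor=0.18215` is an assumption for a Stable Diffusion v1-style VAE.

# Usage sketch only: placeholder endpoint, dummy latent, assumed SD v1-style scaling factor.
import torch
from diffusers.utils.remote_utils import remote_decode

latent = torch.randn(1, 4, 64, 64, dtype=torch.float16)  # dummy latent for illustration

image = remote_decode(
    endpoint="https://<your-vae-decode-endpoint>.endpoints.huggingface.cloud/",  # placeholder URL
    tensor=latent,
    scaling_factor=0.18215,  # assumed VAE config value for SD v1-style latents
    output_type="pil",       # server returns an encoded image; postprocess() opens it with PIL
    return_type="pil",
)
image.save("decoded.jpg")

Internally, `prepare()` now packages the request as a dict of `data`, `params`, and `headers` that is unpacked straight into `requests.post`, and `postprocess()` converts the raw response into the requested output type, so both halves can be reused by other remote utilities.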