from typing import List, Literal, Optional, Union, cast

import requests
- from PIL import Image

+ from .deprecation_utils import deprecate
from .import_utils import is_safetensors_available, is_torch_available


if is_torch_available():
    import torch

    from ..image_processor import VaeImageProcessor
    from ..video_processor import VideoProcessor

    if is_safetensors_available():
-         import safetensors
+         import safetensors.torch
+
    DTYPE_MAP = {
        "float16": torch.float16,
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
        "uint8": torch.uint8,
    }


+ from PIL import Image
+
+
def check_inputs(
    endpoint: str,
    tensor: "torch.Tensor",
    processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
    do_scaling: bool = True,
+     scaling_factor: Optional[float] = None,
+     shift_factor: Optional[float] = None,
    output_type: Literal["mp4", "pil", "pt"] = "pil",
    return_type: Literal["mp4", "pil", "pt"] = "pil",
    image_format: Literal["png", "jpg"] = "jpg",
@@ -48,81 +54,22 @@ def check_inputs(
        and not isinstance(processor, (VaeImageProcessor, VideoProcessor))
    ):
        raise ValueError("`processor` is required.")
-
-
- def _prepare_headers(
-     input_tensor_type: Literal["base64", "binary"],
-     output_type: Literal["mp4", "pil", "pt"],
-     image_format: Literal["png", "jpg"],
-     processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]],
-     output_tensor_type: Literal["base64", "binary"],
- ) -> dict:
-     headers = {}
-     headers["Content-Type"] = "tensor/base64" if input_tensor_type == "base64" else "tensor/binary"
-
-     if output_type == "pil":
-         if processor is None:
-             headers["Accept"] = "image/jpeg" if image_format == "jpg" else "image/png"
-         else:
-             headers["Accept"] = "tensor/base64" if output_tensor_type == "base64" else "tensor/binary"
-     elif output_type == "pt":
-         headers["Accept"] = "tensor/base64" if output_tensor_type == "base64" else "tensor/binary"
-     elif output_type == "mp4":
-         headers["Accept"] = "text/plain"
-     return headers
-
-
- def _prepare_parameters(
-     tensor: "torch.Tensor",
-     do_scaling: bool,
-     output_type: Literal["mp4", "pil", "pt"],
-     partial_postprocess: bool,
-     height: Optional[int],
-     width: Optional[int],
- ) -> dict:
-     params = {
-         "do_scaling": do_scaling,
-         "output_type": output_type,
-         "partial_postprocess": partial_postprocess,
-         "shape": list(tensor.shape),
-         "dtype": str(tensor.dtype).split(".")[-1],
-     }
-     if height is not None and width is not None:
-         params["height"] = height
-         params["width"] = width
-     return params
-
-
- def _encode_tensor_data(tensor: "torch.Tensor", input_tensor_type: Literal["base64", "binary"]) -> dict:
-     tensor_data = safetensors.torch._tobytes(tensor, "tensor")
-     if input_tensor_type == "base64":
-         return {"json": {"inputs": base64.b64encode(tensor_data).decode("utf-8")}}
-     return {"data": tensor_data}
-
-
- def _decode_tensor_response(response: requests.Response, output_tensor_type: Literal["base64", "binary"]):
-     if output_tensor_type == "base64":
-         content = response.json()
-         tensor_bytes = base64.b64decode(content["inputs"])
-         params = content["parameters"]
-     else:
-         tensor_bytes = response.content
-         params = response.headers.copy()
-         params["shape"] = json.loads(params["shape"])
-     return tensor_bytes, params
-
-
- def _tensor_to_pil_images(tensor: "torch.Tensor") -> Union[Image.Image, List[Image.Image]]:
-     # Assuming tensor is [batch, channels, height, width].
-     images = [Image.fromarray((img.permute(1, 2, 0).cpu().numpy() * 255).round().astype("uint8")) for img in tensor]
-     return images[0] if len(images) == 1 else images
+     if do_scaling and scaling_factor is None:
+         deprecate(
+             "do_scaling",
+             "1.0.0",
+             "`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.",
+             standard_warn=False,
+         )


def remote_decode(
    endpoint: str,
    tensor: "torch.Tensor",
    processor: Optional[Union["VaeImageProcessor", "VideoProcessor"]] = None,
    do_scaling: bool = True,
+     scaling_factor: Optional[float] = None,
+     shift_factor: Optional[float] = None,
    output_type: Literal["mp4", "pil", "pt"] = "pil",
    return_type: Literal["mp4", "pil", "pt"] = "pil",
    image_format: Literal["png", "jpg"] = "jpg",
@@ -141,8 +88,16 @@ def remote_decode(
        processor (`VaeImageProcessor` or `VideoProcessor`, *optional*):
            Used with `return_type="pt"`, and `return_type="pil"` for Video models.
        do_scaling (`bool`, default `True`, *optional*):
-             When `True` scaling e.g. `latents / self.vae.config.scaling_factor` is applied remotely. If `False`, input
+             **DEPRECATED**. When `True` scaling e.g. `latents / self.vae.config.scaling_factor` is applied remotely. If `False`, input
            must be passed with scaling applied.
+         scaling_factor (`float`, *optional*):
+             Scaling is applied when passed e.g. `latents / self.vae.config.scaling_factor`.
+             SD v1: 0.18215
+             SD XL: 0.13025
+             Flux: 0.3611
+         shift_factor (`float`, *optional*):
+             Shift is applied when passed e.g. `latents + self.vae.config.shift_factor`.
+             Flux: 0.1159
        output_type (`"mp4"` or `"pil"` or `"pt"`, default `"pil"`):
            **Endpoint** output type. Subject to change. Report feedback on preferred type.

@@ -190,12 +145,13 @@ def remote_decode(
        width (`int`, **optional**):
            Required for `"packed"` latents.
    """
-
    check_inputs(
        endpoint,
        tensor,
        processor,
        do_scaling,
+         scaling_factor,
+         shift_factor,
        output_type,
        return_type,
        image_format,
@@ -205,48 +161,96 @@ def remote_decode(
        height,
        width,
    )
-
-     # Prepare request details.
-     headers = _prepare_headers(input_tensor_type, output_type, image_format, processor, output_tensor_type)
-     params = _prepare_parameters(tensor, do_scaling, output_type, partial_postprocess, height, width)
-     payload = _encode_tensor_data(tensor, input_tensor_type)
-
-     response = requests.post(endpoint, params=params, headers=headers, **payload)
+     headers = {}
+     parameters = {
+         "output_type": output_type,
+         "partial_postprocess": partial_postprocess,
+         "shape": list(tensor.shape),
+         "dtype": str(tensor.dtype).split(".")[-1],
+     }
+     if do_scaling and scaling_factor is not None:
+         parameters["scaling_factor"] = scaling_factor
+     if do_scaling and shift_factor is not None:
+         parameters["shift_factor"] = shift_factor
+     if do_scaling and scaling_factor is None:
+         parameters["do_scaling"] = do_scaling
+     elif do_scaling and scaling_factor is None and shift_factor is None:
+         parameters["do_scaling"] = do_scaling
+     if height is not None and width is not None:
+         parameters["height"] = height
+         parameters["width"] = width
+     tensor_data = safetensors.torch._tobytes(tensor, "tensor")
+     if input_tensor_type == "base64":
+         headers["Content-Type"] = "tensor/base64"
+     elif input_tensor_type == "binary":
+         headers["Content-Type"] = "tensor/binary"
+     if output_type == "pil" and image_format == "jpg" and processor is None:
+         headers["Accept"] = "image/jpeg"
+     elif output_type == "pil" and image_format == "png" and processor is None:
+         headers["Accept"] = "image/png"
+     elif (output_tensor_type == "base64" and output_type == "pt") or (
+         output_tensor_type == "base64" and output_type == "pil" and processor is not None
+     ):
+         headers["Accept"] = "tensor/base64"
+     elif (output_tensor_type == "binary" and output_type == "pt") or (
+         output_tensor_type == "binary" and output_type == "pil" and processor is not None
+     ):
+         headers["Accept"] = "tensor/binary"
+     elif output_type == "mp4":
+         headers["Accept"] = "text/plain"
+     if input_tensor_type == "base64":
+         kwargs = {"json": {"inputs": base64.b64encode(tensor_data).decode("utf-8")}}
+     elif input_tensor_type == "binary":
+         kwargs = {"data": tensor_data}
+     response = requests.post(endpoint, params=parameters, **kwargs, headers=headers)
    if not response.ok:
        raise RuntimeError(response.json())
-
-     # Process responses that return a tensor.
-     if output_type in ("pt",) or (output_type == "pil" and processor is not None):
-         tensor_bytes, tensor_params = _decode_tensor_response(response, output_tensor_type)
-         shape = tensor_params["shape"]
-         dtype = tensor_params["dtype"]
+     if output_type == "pt" or (output_type == "pil" and processor is not None):
+         if output_tensor_type == "base64":
+             content = response.json()
+             output_tensor = base64.b64decode(content["inputs"])
+             parameters = content["parameters"]
+             shape = parameters["shape"]
+             dtype = parameters["dtype"]
+         elif output_tensor_type == "binary":
+             output_tensor = response.content
+             parameters = response.headers
+             shape = json.loads(parameters["shape"])
+             dtype = parameters["dtype"]
        torch_dtype = DTYPE_MAP[dtype]
-         output_tensor = torch.frombuffer(bytearray(tensor_bytes), dtype=torch_dtype).reshape(shape)
-
-     if output_type == "pt":
-         if partial_postprocess:
-             if return_type == "pil":
-                 return _tensor_to_pil_images(output_tensor)
-             return output_tensor
+         output_tensor = torch.frombuffer(bytearray(output_tensor), dtype=torch_dtype).reshape(shape)
+     if output_type == "pt":
+         if partial_postprocess:
+             if return_type == "pil":
+                 output = [Image.fromarray(image.numpy()) for image in output_tensor]
+                 if len(output) == 1:
+                     output = output[0]
+             elif return_type == "pt":
+                 output = output_tensor
+         else:
+             if processor is None or return_type == "pt":
+                 output = output_tensor
            else:
-             if processor is None or return_type == "pt":
-                 return output_tensor
                if isinstance(processor, VideoProcessor):
-                 return cast(List[Image.Image], processor.postprocess_video(output_tensor, output_type="pil")[0])
-             return cast(Image.Image, processor.postprocess(output_tensor, output_type="pil")[0])
-
-     if output_type == "pil" and processor is None and return_type == "pil":
-         return Image.open(io.BytesIO(response.content)).convert("RGB")
-
-     if output_type == "pil" and processor is not None:
-         tensor_bytes, tensor_params = _decode_tensor_response(response, output_tensor_type)
-         shape = tensor_params["shape"]
-         dtype = tensor_params["dtype"]
-         torch_dtype = DTYPE_MAP[dtype]
-         output_tensor = torch.frombuffer(bytearray(tensor_bytes), dtype=torch_dtype).reshape(shape)
+                     output = cast(
+                         List[Image.Image],
+                         processor.postprocess_video(output_tensor, output_type="pil")[0],
+                     )
+                 else:
+                     output = cast(
+                         Image.Image,
+                         processor.postprocess(output_tensor, output_type="pil")[0],
+                     )
+     elif output_type == "pil" and return_type == "pil" and processor is None:
+         output = Image.open(io.BytesIO(response.content)).convert("RGB")
+     elif output_type == "pil" and processor is not None:
        if return_type == "pil":
-             return _tensor_to_pil_images(output_tensor)
-         return output_tensor
-
-     if output_type == "mp4" and return_type == "mp4":
-         return response.content
+             output = [
+                 Image.fromarray(image)
+                 for image in (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255).round().astype("uint8")
+             ]
+         elif return_type == "pt":
+             output = output_tensor
+     elif output_type == "mp4" and return_type == "mp4":
+         output = response.content
+     return output
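
For reference, a minimal usage sketch of the new arguments (not part of the commit). It assumes the function is importable as `diffusers.utils.remote_utils.remote_decode`; the endpoint URL and latent shape are placeholders, and the Flux factors are the values listed in the docstring above.

# Hypothetical example: decode Flux latents with explicit scaling/shift factors
# instead of the deprecated `do_scaling` flag.
import torch

from diffusers.utils.remote_utils import remote_decode  # assumed import path

latents = torch.randn(1, 16, 128, 128, dtype=torch.float16)  # illustrative, unpacked Flux latents

image = remote_decode(
    endpoint="https://<your-vae-decode-endpoint>",  # placeholder endpoint URL
    tensor=latents,
    scaling_factor=0.3611,  # Flux value from the docstring
    shift_factor=0.1159,    # Flux value from the docstring
    output_type="pt",
    partial_postprocess=True,  # endpoint returns a uint8 image tensor
    return_type="pil",
)
image.save("decoded.png")

# For 3-D "packed" latents, `height` and `width` must also be passed (see the docstring).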