2828from PIL import Image
2929
3030
31+ def check_inputs (
32+ endpoint : str ,
33+ tensor : "torch.Tensor" ,
34+ processor : Optional [Union ["VaeImageProcessor" , "VideoProcessor" ]] = None ,
35+ do_scaling : bool = True ,
36+ output_type : Literal ["mp4" , "pil" , "pt" ] = "pil" ,
37+ return_type : Literal ["mp4" , "pil" , "pt" ] = "pil" ,
38+ image_format : Literal ["png" , "jpg" ] = "jpg" ,
39+ partial_postprocess : bool = False ,
40+ input_tensor_type : Literal ["base64" , "binary" ] = "base64" ,
41+ output_tensor_type : Literal ["base64" , "binary" ] = "base64" ,
42+ height : Optional [int ] = None ,
43+ width : Optional [int ] = None ,
44+ ):
45+ if tensor .ndim == 3 and height is None and width is None :
46+ raise ValueError ("`height` and `width` required for packed latents." )
47+ if (
48+ output_type == "pt"
49+ and return_type == "pil"
50+ and not partial_postprocess
51+ and not isinstance (processor , (VaeImageProcessor , VideoProcessor ))
52+ ):
53+ raise ValueError ("`processor` is required." )
54+
55+
3156def remote_decode (
3257 endpoint : str ,
3358 tensor : "torch.Tensor" ,
@@ -63,7 +88,7 @@ def remote_decode(
6388 Requires `processor` as a flag (any `None` value will work).
6489 `"pt"`: Support by image and video models. Endpoint returns `torch.Tensor`.
6590 With `partial_postprocess=True` the tensor is postprocessed `uint8` image tensor.
66-
91+
6792 Recommendations:
6893 `"pt"` with `partial_postprocess=True` is the smallest transfer for full quality.
6994 `"pt"` with `partial_postprocess=False` is the most compatible with third party code.
@@ -85,28 +110,38 @@ def remote_decode(
85110
86111 image_format (`"png"` or `"jpg"`, default `jpg`):
87112 Used with `output_type="pil"`. Endpoint returns `jpg` or `png`.
88-
113+
89114 partial_postprocess (`bool`, default `False`):
90115 Used with `output_type="pt"`.
91116 `partial_postprocess=False` tensor is `float16` or `bfloat16`, without denormalization.
92117 `partial_postprocess=True` tensor is `uint8`, denormalized.
93-
118+
94119 input_tensor_type (`"base64"` or `"binary"`, default `"base64"`):
95120 With `"base64"` `tensor` is sent to endpoint base64 encoded. `"binary"` reduces overhead and transfer.
96121
97122 output_tensor_type (`"base64"` or `"binary"`, default `"base64"`):
98123 With `"base64"` `tensor` returned by endpoint is base64 encoded. `"binary"` reduces overhead and transfer.
99-
124+
100125 height (`int`, **optional**):
101126 Required for `"packed"` latents.
102127
103128 width (`int`, **optional**):
104129 Required for `"packed"` latents.
105130 """
106- if tensor .ndim == 3 and height is None and width is None :
107- raise ValueError ("`height` and `width` required for packed latents." )
108- if output_type == "pt" and partial_postprocess is False and processor is None :
109- raise ValueError ("`processor` is required with `output_type='pt' and `partial_postprocess=False`." )
131+ check_inputs (
132+ endpoint ,
133+ tensor ,
134+ processor ,
135+ do_scaling ,
136+ output_type ,
137+ return_type ,
138+ image_format ,
139+ partial_postprocess ,
140+ input_tensor_type ,
141+ output_tensor_type ,
142+ height ,
143+ width ,
144+ )
110145 headers = {}
111146 parameters = {
112147 "do_scaling" : do_scaling ,
@@ -160,11 +195,14 @@ def remote_decode(
160195 output_tensor = torch .frombuffer (bytearray (output_tensor ), dtype = torch_dtype ).reshape (shape )
161196 if output_type == "pt" :
162197 if partial_postprocess :
163- output = [Image .fromarray (image .numpy ()) for image in output_tensor ]
164- if len (output ) == 1 :
165- output = output [0 ]
198+ if return_type == "pil" :
199+ output = [Image .fromarray (image .numpy ()) for image in output_tensor ]
200+ if len (output ) == 1 :
201+ output = output [0 ]
202+ elif return_type == "pt" :
203+ output = output_tensor
166204 else :
167- if processor is None :
205+ if processor is None or return_type == "pt" :
168206 output = output_tensor
169207 else :
170208 if isinstance (processor , VideoProcessor ):
@@ -177,13 +215,16 @@ def remote_decode(
177215 Image .Image ,
178216 processor .postprocess (output_tensor , output_type = "pil" )[0 ],
179217 )
180- elif output_type == "pil" and processor is None :
218+ elif output_type == "pil" and return_type == "pil" and processor is None :
181219 output = Image .open (io .BytesIO (response .content )).convert ("RGB" )
182220 elif output_type == "pil" and processor is not None :
183- output = [
184- Image .fromarray (image )
185- for image in (output_tensor .permute (0 , 2 , 3 , 1 ).float ().numpy () * 255 ).round ().astype ("uint8" )
186- ]
187- elif output_type == "mp4" :
221+ if return_type == "pil" :
222+ output = [
223+ Image .fromarray (image )
224+ for image in (output_tensor .permute (0 , 2 , 3 , 1 ).float ().numpy () * 255 ).round ().astype ("uint8" )
225+ ]
226+ elif return_type == "pt" :
227+ output = output_tensor
228+ elif output_type == "mp4" and return_type == "mp4" :
188229 output = response .content
189230 return output
0 commit comments