2828from  PIL  import  Image 
2929
3030
31+ def  check_inputs (
32+     endpoint : str ,
33+     tensor : "torch.Tensor" ,
34+     processor : Optional [Union ["VaeImageProcessor" , "VideoProcessor" ]] =  None ,
35+     do_scaling : bool  =  True ,
36+     output_type : Literal ["mp4" , "pil" , "pt" ] =  "pil" ,
37+     return_type : Literal ["mp4" , "pil" , "pt" ] =  "pil" ,
38+     image_format : Literal ["png" , "jpg" ] =  "jpg" ,
39+     partial_postprocess : bool  =  False ,
40+     input_tensor_type : Literal ["base64" , "binary" ] =  "base64" ,
41+     output_tensor_type : Literal ["base64" , "binary" ] =  "base64" ,
42+     height : Optional [int ] =  None ,
43+     width : Optional [int ] =  None ,
44+ ):
45+     if  tensor .ndim  ==  3  and  height  is  None  and  width  is  None :
46+         raise  ValueError ("`height` and `width` required for packed latents." )
47+     if  (
48+         output_type  ==  "pt" 
49+         and  return_type  ==  "pil" 
50+         and  not  partial_postprocess 
51+         and  not  isinstance (processor , (VaeImageProcessor , VideoProcessor ))
52+     ):
53+         raise  ValueError ("`processor` is required." )
54+ 
55+ 
3156def  remote_decode (
3257    endpoint : str ,
3358    tensor : "torch.Tensor" ,
@@ -63,7 +88,7 @@ def remote_decode(
6388                    Requires `processor` as a flag (any `None` value will work). 
6489            `"pt"`: Support by image and video models. Endpoint returns `torch.Tensor`. 
6590                With `partial_postprocess=True` the tensor is postprocessed `uint8` image tensor. 
66-              
91+ 
6792            Recommendations: 
6893                `"pt"` with `partial_postprocess=True` is the smallest transfer for full quality. 
6994                `"pt"` with `partial_postprocess=False` is the most compatible with third party code. 
@@ -85,28 +110,38 @@ def remote_decode(
85110
86111        image_format (`"png"` or `"jpg"`, default `jpg`): 
87112            Used with `output_type="pil"`. Endpoint returns `jpg` or `png`. 
88-          
113+ 
89114        partial_postprocess (`bool`, default `False`): 
90115            Used with `output_type="pt"`. 
91116            `partial_postprocess=False` tensor is `float16` or `bfloat16`, without denormalization. 
92117            `partial_postprocess=True` tensor is `uint8`, denormalized. 
93-          
118+ 
94119        input_tensor_type (`"base64"` or `"binary"`, default `"base64"`): 
95120            With `"base64"` `tensor` is sent to endpoint base64 encoded. `"binary"` reduces overhead and transfer. 
96121
97122        output_tensor_type (`"base64"` or `"binary"`, default `"base64"`): 
98123            With `"base64"` `tensor` returned by endpoint is base64 encoded. `"binary"` reduces overhead and transfer. 
99-          
124+ 
100125        height (`int`, **optional**): 
101126            Required for `"packed"` latents. 
102127
103128        width (`int`, **optional**): 
104129            Required for `"packed"` latents. 
105130    """ 
106-     if  tensor .ndim  ==  3  and  height  is  None  and  width  is  None :
107-         raise  ValueError ("`height` and `width` required for packed latents." )
108-     if  output_type  ==  "pt"  and  partial_postprocess  is  False  and  processor  is  None :
109-         raise  ValueError ("`processor` is required with `output_type='pt' and `partial_postprocess=False`." )
131+     check_inputs (
132+         endpoint ,
133+         tensor ,
134+         processor ,
135+         do_scaling ,
136+         output_type ,
137+         return_type ,
138+         image_format ,
139+         partial_postprocess ,
140+         input_tensor_type ,
141+         output_tensor_type ,
142+         height ,
143+         width ,
144+     )
110145    headers  =  {}
111146    parameters  =  {
112147        "do_scaling" : do_scaling ,
@@ -160,11 +195,14 @@ def remote_decode(
160195        output_tensor  =  torch .frombuffer (bytearray (output_tensor ), dtype = torch_dtype ).reshape (shape )
161196    if  output_type  ==  "pt" :
162197        if  partial_postprocess :
163-             output  =  [Image .fromarray (image .numpy ()) for  image  in  output_tensor ]
164-             if  len (output ) ==  1 :
165-                 output  =  output [0 ]
198+             if  return_type  ==  "pil" :
199+                 output  =  [Image .fromarray (image .numpy ()) for  image  in  output_tensor ]
200+                 if  len (output ) ==  1 :
201+                     output  =  output [0 ]
202+             elif  return_type  ==  "pt" :
203+                 output  =  output_tensor 
166204        else :
167-             if  processor  is  None :
205+             if  processor  is  None   or   return_type   ==   "pt" :
168206                output  =  output_tensor 
169207            else :
170208                if  isinstance (processor , VideoProcessor ):
@@ -177,13 +215,16 @@ def remote_decode(
177215                        Image .Image ,
178216                        processor .postprocess (output_tensor , output_type = "pil" )[0 ],
179217                    )
180-     elif  output_type  ==  "pil"  and  processor  is  None :
218+     elif  output_type  ==  "pil"  and  return_type   ==   "pil"   and   processor  is  None :
181219        output  =  Image .open (io .BytesIO (response .content )).convert ("RGB" )
182220    elif  output_type  ==  "pil"  and  processor  is  not None :
183-         output  =  [
184-             Image .fromarray (image )
185-             for  image  in  (output_tensor .permute (0 , 2 , 3 , 1 ).float ().numpy () *  255 ).round ().astype ("uint8" )
186-         ]
187-     elif  output_type  ==  "mp4" :
221+         if  return_type  ==  "pil" :
222+             output  =  [
223+                 Image .fromarray (image )
224+                 for  image  in  (output_tensor .permute (0 , 2 , 3 , 1 ).float ().numpy () *  255 ).round ().astype ("uint8" )
225+             ]
226+         elif  return_type  ==  "pt" :
227+             output  =  output_tensor 
228+     elif  output_type  ==  "mp4"  and  return_type  ==  "mp4" :
188229        output  =  response .content 
189230    return  output 
0 commit comments