44from  typing  import  List , Literal , Optional , Union , cast 
55
66import  requests 
7+ from  PIL  import  Image 
78
89from  .import_utils  import  is_safetensors_available , is_torch_available 
910
1617
1718    if  is_safetensors_available ():
1819        import  safetensors 
19- 
2020    DTYPE_MAP  =  {
2121        "float16" : torch .float16 ,
2222        "float32" : torch .float32 ,
2525    }
2626
2727
28- from  PIL  import  Image 
29- 
30- 
3128def  check_inputs (
3229    endpoint : str ,
3330    tensor : "torch.Tensor" ,
@@ -53,6 +50,74 @@ def check_inputs(
5350        raise  ValueError ("`processor` is required." )
5451
5552
53+ def  _prepare_headers (
54+     input_tensor_type : Literal ["base64" , "binary" ],
55+     output_type : Literal ["mp4" , "pil" , "pt" ],
56+     image_format : Literal ["png" , "jpg" ],
57+     processor : Optional [Union ["VaeImageProcessor" , "VideoProcessor" ]],
58+     output_tensor_type : Literal ["base64" , "binary" ],
59+ ) ->  dict :
60+     headers  =  {}
61+     headers ["Content-Type" ] =  "tensor/base64"  if  input_tensor_type  ==  "base64"  else  "tensor/binary" 
62+ 
63+     if  output_type  ==  "pil" :
64+         if  processor  is  None :
65+             headers ["Accept" ] =  "image/jpeg"  if  image_format  ==  "jpg"  else  "image/png" 
66+         else :
67+             headers ["Accept" ] =  "tensor/base64"  if  output_tensor_type  ==  "base64"  else  "tensor/binary" 
68+     elif  output_type  ==  "pt" :
69+         headers ["Accept" ] =  "tensor/base64"  if  output_tensor_type  ==  "base64"  else  "tensor/binary" 
70+     elif  output_type  ==  "mp4" :
71+         headers ["Accept" ] =  "text/plain" 
72+     return  headers 
73+ 
74+ 
75+ def  _prepare_parameters (
76+     tensor : "torch.Tensor" ,
77+     do_scaling : bool ,
78+     output_type : Literal ["mp4" , "pil" , "pt" ],
79+     partial_postprocess : bool ,
80+     height : Optional [int ],
81+     width : Optional [int ],
82+ ) ->  dict :
83+     params  =  {
84+         "do_scaling" : do_scaling ,
85+         "output_type" : output_type ,
86+         "partial_postprocess" : partial_postprocess ,
87+         "shape" : list (tensor .shape ),
88+         "dtype" : str (tensor .dtype ).split ("." )[- 1 ],
89+     }
90+     if  height  is  not None  and  width  is  not None :
91+         params ["height" ] =  height 
92+         params ["width" ] =  width 
93+     return  params 
94+ 
95+ 
96+ def  _encode_tensor_data (tensor : "torch.Tensor" , input_tensor_type : Literal ["base64" , "binary" ]) ->  dict :
97+     tensor_data  =  safetensors .torch ._tobytes (tensor , "tensor" )
98+     if  input_tensor_type  ==  "base64" :
99+         return  {"json" : {"inputs" : base64 .b64encode (tensor_data ).decode ("utf-8" )}}
100+     return  {"data" : tensor_data }
101+ 
102+ 
103+ def  _decode_tensor_response (response : requests .Response , output_tensor_type : Literal ["base64" , "binary" ]):
104+     if  output_tensor_type  ==  "base64" :
105+         content  =  response .json ()
106+         tensor_bytes  =  base64 .b64decode (content ["inputs" ])
107+         params  =  content ["parameters" ]
108+     else :
109+         tensor_bytes  =  response .content 
110+         params  =  response .headers .copy ()
111+         params ["shape" ] =  json .loads (params ["shape" ])
112+     return  tensor_bytes , params 
113+ 
114+ 
115+ def  _tensor_to_pil_images (tensor : "torch.Tensor" ) ->  Union [Image .Image , List [Image .Image ]]:
116+     # Assuming tensor is [batch, channels, height, width]. 
117+     images  =  [Image .fromarray ((img .permute (1 , 2 , 0 ).cpu ().numpy () *  255 ).round ().astype ("uint8" )) for  img  in  tensor ]
118+     return  images [0 ] if  len (images ) ==  1  else  images 
119+ 
120+ 
56121def  remote_decode (
57122    endpoint : str ,
58123    tensor : "torch.Tensor" ,
@@ -125,6 +190,7 @@ def remote_decode(
125190        width (`int`, **optional**): 
126191            Required for `"packed"` latents. 
127192    """ 
193+ 
128194    check_inputs (
129195        endpoint ,
130196        tensor ,
@@ -139,89 +205,48 @@ def remote_decode(
139205        height ,
140206        width ,
141207    )
142-     headers  =  {}
143-     parameters  =  {
144-         "do_scaling" : do_scaling ,
145-         "output_type" : output_type ,
146-         "partial_postprocess" : partial_postprocess ,
147-         "shape" : list (tensor .shape ),
148-         "dtype" : str (tensor .dtype ).split ("." )[- 1 ],
149-     }
150-     if  height  is  not None  and  width  is  not None :
151-         parameters ["height" ] =  height 
152-         parameters ["width" ] =  width 
153-     tensor_data  =  safetensors .torch ._tobytes (tensor , "tensor" )
154-     if  input_tensor_type  ==  "base64" :
155-         headers ["Content-Type" ] =  "tensor/base64" 
156-     elif  input_tensor_type  ==  "binary" :
157-         headers ["Content-Type" ] =  "tensor/binary" 
158-     if  output_type  ==  "pil"  and  image_format  ==  "jpg"  and  processor  is  None :
159-         headers ["Accept" ] =  "image/jpeg" 
160-     elif  output_type  ==  "pil"  and  image_format  ==  "png"  and  processor  is  None :
161-         headers ["Accept" ] =  "image/png" 
162-     elif  (output_tensor_type  ==  "base64"  and  output_type  ==  "pt" ) or  (
163-         output_tensor_type  ==  "base64"  and  output_type  ==  "pil"  and  processor  is  not None 
164-     ):
165-         headers ["Accept" ] =  "tensor/base64" 
166-     elif  (output_tensor_type  ==  "binary"  and  output_type  ==  "pt" ) or  (
167-         output_tensor_type  ==  "binary"  and  output_type  ==  "pil"  and  processor  is  not None 
168-     ):
169-         headers ["Accept" ] =  "tensor/binary" 
170-     elif  output_type  ==  "mp4" :
171-         headers ["Accept" ] =  "text/plain" 
172-     if  input_tensor_type  ==  "base64" :
173-         kwargs  =  {"json" : {"inputs" : base64 .b64encode (tensor_data ).decode ("utf-8" )}}
174-     elif  input_tensor_type  ==  "binary" :
175-         kwargs  =  {"data" : tensor_data }
176-     response  =  requests .post (endpoint , params = parameters , ** kwargs , headers = headers )
208+ 
209+     # Prepare request details. 
210+     headers  =  _prepare_headers (input_tensor_type , output_type , image_format , processor , output_tensor_type )
211+     params  =  _prepare_parameters (tensor , do_scaling , output_type , partial_postprocess , height , width )
212+     payload  =  _encode_tensor_data (tensor , input_tensor_type )
213+ 
214+     response  =  requests .post (endpoint , params = params , headers = headers , ** payload )
177215    if  not  response .ok :
178216        raise  RuntimeError (response .json ())
179-     if  output_type  ==  "pt"  or  (output_type  ==  "pil"  and  processor  is  not None ):
180-         if  output_tensor_type  ==  "base64" :
181-             content  =  response .json ()
182-             output_tensor  =  base64 .b64decode (content ["inputs" ])
183-             parameters  =  content ["parameters" ]
184-             shape  =  parameters ["shape" ]
185-             dtype  =  parameters ["dtype" ]
186-         elif  output_tensor_type  ==  "binary" :
187-             output_tensor  =  response .content 
188-             parameters  =  response .headers 
189-             shape  =  json .loads (parameters ["shape" ])
190-             dtype  =  parameters ["dtype" ]
217+ 
218+     # Process responses that return a tensor. 
219+     if  output_type  in  ("pt" ,) or  (output_type  ==  "pil"  and  processor  is  not None ):
220+         tensor_bytes , tensor_params  =  _decode_tensor_response (response , output_tensor_type )
221+         shape  =  tensor_params ["shape" ]
222+         dtype  =  tensor_params ["dtype" ]
191223        torch_dtype  =  DTYPE_MAP [dtype ]
192-         output_tensor  =  torch .frombuffer (bytearray (output_tensor ), dtype = torch_dtype ).reshape (shape )
193-     if  output_type  ==  "pt" :
194-         if  partial_postprocess :
195-             if  return_type  ==  "pil" :
196-                 output  =  [Image .fromarray (image .numpy ()) for  image  in  output_tensor ]
197-                 if  len (output ) ==  1 :
198-                     output  =  output [0 ]
199-             elif  return_type  ==  "pt" :
200-                 output  =  output_tensor 
201-         else :
202-             if  processor  is  None  or  return_type  ==  "pt" :
203-                 output  =  output_tensor 
224+         output_tensor  =  torch .frombuffer (bytearray (tensor_bytes ), dtype = torch_dtype ).reshape (shape )
225+ 
226+         if  output_type  ==  "pt" :
227+             if  partial_postprocess :
228+                 if  return_type  ==  "pil" :
229+                     return  _tensor_to_pil_images (output_tensor )
230+                 return  output_tensor 
204231            else :
232+                 if  processor  is  None  or  return_type  ==  "pt" :
233+                     return  output_tensor 
205234                if  isinstance (processor , VideoProcessor ):
206-                     output   =   cast (
207-                          List [ Image .Image ], 
208-                          processor . postprocess_video ( output_tensor ,  output_type = "pil" )[ 0 ], 
209-                     ) 
210-                  else : 
211-                      output   =   cast ( 
212-                          Image . Image , 
213-                          processor . postprocess ( output_tensor ,  output_type = "pil" )[ 0 ], 
214-                     ) 
215-     elif   output_type   ==   "pil"   and   return_type   ==   "pil"   and   processor   is   None : 
216-         output  =  Image . open ( io . BytesIO ( response . content )). convert ( "RGB" ) 
217-     elif   output_type   ==   "pil"   and   processor   is   not   None : 
235+                     return   cast (List [ Image . Image ],  processor . postprocess_video ( output_tensor ,  output_type = "pil" )[ 0 ]) 
236+                 return   cast ( Image .Image ,  processor . postprocess ( output_tensor ,  output_type = "pil" )[ 0 ]) 
237+ 
238+     if   output_type   ==   "pil"   and   processor   is   None   and   return_type   ==   "pil" : 
239+         return   Image . open ( io . BytesIO ( response . content )). convert ( "RGB" ) 
240+ 
241+     if   output_type   ==   "pil"   and   processor   is   not   None : 
242+         tensor_bytes ,  tensor_params   =   _decode_tensor_response ( response ,  output_tensor_type ) 
243+         shape   =   tensor_params [ "shape" ] 
244+          dtype   =   tensor_params [ "dtype" ] 
245+         torch_dtype  =  DTYPE_MAP [ dtype ] 
246+          output_tensor   =   torch . frombuffer ( bytearray ( tensor_bytes ),  dtype = torch_dtype ). reshape ( shape ) 
218247        if  return_type  ==  "pil" :
219-             output  =  [
220-                 Image .fromarray (image )
221-                 for  image  in  (output_tensor .permute (0 , 2 , 3 , 1 ).float ().numpy () *  255 ).round ().astype ("uint8" )
222-             ]
223-         elif  return_type  ==  "pt" :
224-             output  =  output_tensor 
225-     elif  output_type  ==  "mp4"  and  return_type  ==  "mp4" :
226-         output  =  response .content 
227-     return  output 
248+             return  _tensor_to_pil_images (output_tensor )
249+         return  output_tensor 
250+ 
251+     if  output_type  ==  "mp4"  and  return_type  ==  "mp4" :
252+         return  response .content 
0 commit comments