@@ -57,9 +57,9 @@ def predict_wrapper(
5757
5858  def  generate_wrapper (
5959      self , request : PostModelOutputsRequest ) ->  Iterator [service_pb2 .MultiOutputResponse ]:
60-     list_dict_input , inference_params  =  self .parse_input_request (request )
6160    if  self .download_request_urls :
6261      ensure_urls_downloaded (request )
62+     list_dict_input , inference_params  =  self .parse_input_request (request )
6363    outputs  =  self .generate (list_dict_input , inference_parameters = inference_params )
6464    for  output  in  outputs :
6565      yield  self .convert_output_to_proto (output )
@@ -71,13 +71,13 @@ def _preprocess_stream(
7171      input_data , _  =  self .parse_input_request (req )
7272      yield  input_data 
7373
74-   def  stream_wrapper (self , request : Iterator [PostModelOutputsRequest ]
74+   def  stream_wrapper (self , request_iterator : Iterator [PostModelOutputsRequest ]
7575                    ) ->  Iterator [service_pb2 .MultiOutputResponse ]:
76-     first_request  =  next (request )
77-     _ , inference_params  =  self .parse_input_request (first_request )
78-     request_iterator  =  itertools .chain ([first_request ], request )
7976    if  self .download_request_urls :
8077      request_iterator  =  readahead (map (ensure_urls_downloaded , request_iterator ))
78+     first_request  =  next (request_iterator )
79+     _ , inference_params  =  self .parse_input_request (first_request )
80+     request_iterator  =  itertools .chain ([first_request ], request_iterator )
8181    outputs  =  self .stream (self ._preprocess_stream (request_iterator ), inference_params )
8282    for  output  in  outputs :
8383      yield  self .convert_output_to_proto (output )
0 commit comments