@@ -147,7 +147,6 @@ def __init__(
147147 self .fields = fields if fields else []
148148 self .condition_value_fn = condition_value_fn
149149 self .query_fn = query_fn
150- self ._has_placeholders = '{}' in self .row_restriction_template
151150 self .query_template = (
152151 "SELECT %s FROM %s WHERE %s" %
153152 (self .select_fields , self .table_name , self .row_restriction_template ))
@@ -199,41 +198,34 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
199198 [fr'({ self .row_restriction_template } )' ] * batch_size )
200199 raw_query = self .query_template .replace (
201200 self .row_restriction_template , batched_condition_template )
202- if self ._has_placeholders :
203- for req in request :
204- request_dict = req ._asdict ()
205- try :
206- current_values = (
207- self .condition_value_fn (req ) if self .condition_value_fn else
208- [request_dict [field ] for field in self .fields ])
209- except KeyError as e :
210- raise KeyError (
211- "Make sure the values passed in `fields` are the "
212- "keys in the input `beam.Row`." + str (e ))
213- values .extend (current_values )
214- requests_map [self .create_row_key (req )] = req
215- query = raw_query .format (* values )
216- else :
217- for req in request :
218- requests_map [id (req )] = req # Use object id as key
219- query = raw_query
201+ for req in request :
202+ request_dict = req ._asdict ()
203+ try :
204+ current_values = (
205+ self .condition_value_fn (req ) if self .condition_value_fn else
206+ [request_dict [field ] for field in self .fields ])
207+ except KeyError as e :
208+ raise KeyError (
209+ "Make sure the values passed in `fields` are the "
210+ "keys in the input `beam.Row`." + str (e ))
211+ values .extend (current_values )
212+ requests_map [self .create_row_key (req )] = req
213+ query = raw_query .format (* values )
220214
221215 responses_dict = self ._execute_query (query )
222216 unmatched_requests = requests_map .copy ()
223217 if responses_dict :
224- if self ._has_placeholders :
225- for response in responses_dict :
226- response_row = beam .Row (** response )
227- response_key = self .create_row_key (response_row )
228- if response_key in unmatched_requests :
229- req = unmatched_requests .pop (response_key )
230- responses .append ((req , response_row ))
231- else :
232- if responses_dict :
233- response_row = beam .Row (** responses_dict [0 ])
234- for req in unmatched_requests .values ():
235- responses .append ((req , response_row ))
236- unmatched_requests .clear ()
218+ for response in responses_dict :
219+ response_row = beam .Row (** response )
220+ response_key = self .create_row_key (response_row )
221+ if response_key in unmatched_requests :
222+ req = unmatched_requests .pop (response_key )
223+ responses .append ((req , response_row ))
224+ if unmatched_requests and responses_dict :
225+ response_row = beam .Row (** responses_dict [0 ])
226+ for req in unmatched_requests .values ():
227+ responses .append ((req , response_row ))
228+ unmatched_requests .clear ()
237229 if unmatched_requests :
238230 if self .throw_exception_on_empty_results :
239231 raise ValueError (f"no matching row found for query: { query } " )
@@ -249,13 +241,10 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
249241 # that should be populated into the query template string.
250242 query = self .query_fn (request )
251243 else :
252- if self ._has_placeholders :
253- values = (
254- self .condition_value_fn (request ) if self .condition_value_fn else
255- list (map (request_dict .get , self .fields )))
256- query = self .query_template .format (* values )
257- else :
258- query = self .query_template
244+ values = (
245+ self .condition_value_fn (request ) if self .condition_value_fn else
246+ list (map (request_dict .get , self .fields )))
247+ query = self .query_template .format (* values )
259248 response_dict = self ._execute_query (query )
260249 if response_dict is None :
261250 if self .throw_exception_on_empty_results :
0 commit comments