@@ -141,11 +141,12 @@ def __init__(
141141 self .project = project
142142 self .column_names = column_names
143143 self .select_fields = "," .join (column_names ) if column_names else '*'
144- self .row_restriction_template = row_restriction_template
144+ self .row_restriction_template = row_restriction_template . replace ( '{{}}' , '{}' )
145145 self .table_name = table_name
146146 self .fields = fields if fields else []
147147 self .condition_value_fn = condition_value_fn
148148 self .query_fn = query_fn
149+ self ._has_placeholders = '{}' in self .row_restriction_template
149150 self .query_template = (
150151 "SELECT %s FROM %s WHERE %s" %
151152 (self .select_fields , self .table_name , self .row_restriction_template ))
@@ -197,29 +198,41 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
197198 [fr'({ self .row_restriction_template } )' ] * batch_size )
198199 raw_query = self .query_template .replace (
199200 self .row_restriction_template , batched_condition_template )
200- for req in request :
201- request_dict = req ._asdict ()
202- try :
203- current_values = (
204- self .condition_value_fn (req ) if self .condition_value_fn else
205- [request_dict [field ] for field in self .fields ])
206- except KeyError as e :
207- raise KeyError (
208- "Make sure the values passed in `fields` are the "
209- "keys in the input `beam.Row`." + str (e ))
210- values .extend (current_values )
211- requests_map [self .create_row_key (req )] = req
212- query = raw_query .format (* values )
201+ if self ._has_placeholders :
202+ for req in request :
203+ request_dict = req ._asdict ()
204+ try :
205+ current_values = (
206+ self .condition_value_fn (req ) if self .condition_value_fn else
207+ [request_dict [field ] for field in self .fields ])
208+ except KeyError as e :
209+ raise KeyError (
210+ "Make sure the values passed in `fields` are the "
211+ "keys in the input `beam.Row`." + str (e ))
212+ values .extend (current_values )
213+ requests_map [self .create_row_key (req )] = req
214+ query = raw_query .format (* values )
215+ else :
216+ for req in request :
217+ requests_map [id (req )] = req # Use object id as key
218+ query = raw_query
213219
214220 responses_dict = self ._execute_query (query )
215221 unmatched_requests = requests_map .copy ()
216222 if responses_dict :
217- for response in responses_dict :
218- response_row = beam .Row (** response )
219- response_key = self .create_row_key (response_row )
220- if response_key in unmatched_requests :
221- req = unmatched_requests .pop (response_key )
222- responses .append ((req , response_row ))
223+ if self ._has_placeholders :
224+ for response in responses_dict :
225+ response_row = beam .Row (** response )
226+ response_key = self .create_row_key (response_row )
227+ if response_key in unmatched_requests :
228+ req = unmatched_requests .pop (response_key )
229+ responses .append ((req , response_row ))
230+ else :
231+ if responses_dict :
232+ response_row = beam .Row (** responses_dict [0 ])
233+ for req in unmatched_requests .values ():
234+ responses .append ((req , response_row ))
235+ unmatched_requests .clear ()
223236 if unmatched_requests :
224237 if self .throw_exception_on_empty_results :
225238 raise ValueError (f"no matching row found for query: { query } " )
@@ -235,11 +248,13 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
235248 # that should be populated into the query template string.
236249 query = self .query_fn (request )
237250 else :
238- values = (
239- self .condition_value_fn (request ) if self .condition_value_fn else
240- list (map (request_dict .get , self .fields )))
241- # construct the query.
242- query = self .query_template .format (* values )
251+ if self ._has_placeholders :
252+ values = (
253+ self .condition_value_fn (request ) if self .condition_value_fn else
254+ list (map (request_dict .get , self .fields )))
255+ query = self .query_template .format (* values )
256+ else :
257+ query = self .query_template
243258 response_dict = self ._execute_query (query )
244259 if response_dict is None :
245260 if self .throw_exception_on_empty_results :
0 commit comments