@@ -52,7 +52,7 @@ def _validate_bigquery_metadata(
5252 (not fields and not condition_value_fn )):
5353 raise ValueError (
5454 "Please provide exactly one of `fields` or "
55- "`condition_value_fn`" )
55+ "`condition_value_fn` for matching responses to requests. " )
5656
5757
5858class BigQueryEnrichmentHandler (EnrichmentSourceHandler [Union [Row , list [Row ]],
@@ -106,17 +106,15 @@ def __init__(
106106 project: Google Cloud project ID for the BigQuery table.
107107 table_name (str): Fully qualified BigQuery table name
108108 in the format `project.dataset.table`.
109- row_restriction_template (str): A template string for the `WHERE` clause
110- in the BigQuery query with placeholders (`{}`) to dynamically filter
111- rows based on input data.
109+ row_restriction_template (str): A string for the `WHERE` clause
110+ in the BigQuery query. Used as-is without any formatting.
112111 fields: (Optional[list[str]]) List of field names present in the input
113- `beam.Row`. These are used to construct the WHERE clause
114- (if `condition_value_fn` is not provided).
112+ `beam.Row`. Used for matching responses to requests.
115113 column_names: (Optional[list[str]]) Names of columns to select from the
116114 BigQuery table. If not provided, all columns (`*`) are selected.
117115 condition_value_fn: (Optional[Callable[[beam.Row], Any]]) A function
118- that takes a `beam.Row` and returns a list of value to populate in the
119- placeholder `{}` of `WHERE` clause in the query .
116+ that takes a `beam.Row` and returns a list of values. Used for matching
117+ responses to requests .
120118 query_fn: (Optional[Callable[[beam.Row], str]]) A function that takes a
121119 `beam.Row` and returns a complete BigQuery SQL query string.
122120 min_batch_size (int): Minimum number of rows to batch together when
@@ -187,29 +185,25 @@ def create_row_key(self, row: beam.Row):
187185
188186 def __call__ (self , request : Union [beam .Row , list [beam .Row ]], * args , ** kwargs ):
189187 if isinstance (request , list ):
190- values = []
191188 responses = []
192189 requests_map : dict [Any , Any ] = {}
193- batch_size = len (request )
194- raw_query = self .query_template
195- if batch_size > 1 :
196- batched_condition_template = ' or ' .join (
197- [fr'({ self .row_restriction_template } )' ] * batch_size )
198- raw_query = self .query_template .replace (
199- self .row_restriction_template , batched_condition_template )
190+ if self .fields and len (self .fields ) > 0 :
191+ unique_values = set ()
192+ field_name = self .fields [0 ]
193+ for req in request :
194+ req_dict = req ._asdict ()
195+ unique_values .add (req_dict [field_name ])
196+ if unique_values :
197+ conditions = [f"{ field_name } = '{ val } '" for val in unique_values ]
198+ raw_query = "SELECT %s FROM %s WHERE %s" % (
199+ self .select_fields , self .table_name , " OR " .join (conditions ))
200+ else :
201+ raw_query = self .query_template
202+ else :
203+ raw_query = self .query_template
200204 for req in request :
201- request_dict = req ._asdict ()
202- try :
203- current_values = (
204- self .condition_value_fn (req ) if self .condition_value_fn else
205- [request_dict [field ] for field in self .fields ])
206- except KeyError as e :
207- raise KeyError (
208- "Make sure the values passed in `fields` are the "
209- "keys in the input `beam.Row`." + str (e ))
210- values .extend (current_values )
211205 requests_map [self .create_row_key (req )] = req
212- query = raw_query . format ( * values )
206+ query = raw_query
213207
214208 responses_dict = self ._execute_query (query )
215209 unmatched_requests = requests_map .copy ()
@@ -220,6 +214,11 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
220214 if response_key in unmatched_requests :
221215 req = unmatched_requests .pop (response_key )
222216 responses .append ((req , response_row ))
217+ if unmatched_requests and responses_dict :
218+ response_row = beam .Row (** responses_dict [0 ])
219+ for req in unmatched_requests .values ():
220+ responses .append ((req , response_row ))
221+ unmatched_requests .clear ()
223222 if unmatched_requests :
224223 if self .throw_exception_on_empty_results :
225224 raise ValueError (f"no matching row found for query: { query } " )
@@ -229,17 +228,12 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
229228 responses .append ((req , beam .Row ()))
230229 return responses
231230 else :
232- request_dict = request ._asdict ()
233231 if self .query_fn :
234232 # if a query_fn is provided, it returns the complete SQL query
235233 # string for this request, so it is used directly as the query.
236234 query = self .query_fn (request )
237235 else :
238- values = (
239- self .condition_value_fn (request ) if self .condition_value_fn else
240- list (map (request_dict .get , self .fields )))
241- # construct the query.
242- query = self .query_template .format (* values )
236+ query = self .query_template
243237 response_dict = self ._execute_query (query )
244238 if response_dict is None :
245239 if self .throw_exception_on_empty_results :
0 commit comments