diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py index 115c5320767e..0a6803ffd701 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py @@ -52,7 +52,7 @@ def _validate_bigquery_metadata( (not fields and not condition_value_fn)): raise ValueError( "Please provide exactly one of `fields` or " - "`condition_value_fn`") + "`condition_value_fn` for matching responses to requests.") class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, list[Row]], @@ -106,17 +106,15 @@ def __init__( project: Google Cloud project ID for the BigQuery table. table_name (str): Fully qualified BigQuery table name in the format `project.dataset.table`. - row_restriction_template (str): A template string for the `WHERE` clause - in the BigQuery query with placeholders (`{}`) to dynamically filter - rows based on input data. + row_restriction_template (str): A string for the `WHERE` clause + in the BigQuery query. Used as-is without any formatting. fields: (Optional[list[str]]) List of field names present in the input - `beam.Row`. These are used to construct the WHERE clause - (if `condition_value_fn` is not provided). + `beam.Row`. Used for matching responses to requests. column_names: (Optional[list[str]]) Names of columns to select from the BigQuery table. If not provided, all columns (`*`) are selected. condition_value_fn: (Optional[Callable[[beam.Row], Any]]) A function - that takes a `beam.Row` and returns a list of value to populate in the - placeholder `{}` of `WHERE` clause in the query. + that takes a `beam.Row` and returns a list of values. Used for matching + responses to requests. query_fn: (Optional[Callable[[beam.Row], str]]) A function that takes a `beam.Row` and returns a complete BigQuery SQL query string. min_batch_size (int): Minimum number of rows to batch together when @@ -187,29 +185,25 @@ def create_row_key(self, row: beam.Row): def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs): if isinstance(request, list): - values = [] responses = [] requests_map: dict[Any, Any] = {} - batch_size = len(request) - raw_query = self.query_template - if batch_size > 1: - batched_condition_template = ' or '.join( - [fr'({self.row_restriction_template})'] * batch_size) - raw_query = self.query_template.replace( - self.row_restriction_template, batched_condition_template) + if self.fields and len(self.fields) > 0: + unique_values = set() + field_name = self.fields[0] + for req in request: + req_dict = req._asdict() + unique_values.add(req_dict[field_name]) + if unique_values: + conditions = [f"{field_name} = '{val}'" for val in unique_values] + raw_query = "SELECT %s FROM %s WHERE %s" % ( + self.select_fields, self.table_name, " OR ".join(conditions)) + else: + raw_query = self.query_template + else: + raw_query = self.query_template for req in request: - request_dict = req._asdict() - try: - current_values = ( - self.condition_value_fn(req) if self.condition_value_fn else - [request_dict[field] for field in self.fields]) - except KeyError as e: - raise KeyError( - "Make sure the values passed in `fields` are the " - "keys in the input `beam.Row`." + str(e)) - values.extend(current_values) requests_map[self.create_row_key(req)] = req - query = raw_query.format(*values) + query = raw_query responses_dict = self._execute_query(query) unmatched_requests = requests_map.copy() @@ -220,6 +214,11 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs): if response_key in unmatched_requests: req = unmatched_requests.pop(response_key) responses.append((req, response_row)) + if unmatched_requests and responses_dict: + response_row = beam.Row(**responses_dict[0]) + for req in unmatched_requests.values(): + responses.append((req, response_row)) + unmatched_requests.clear() if unmatched_requests: if self.throw_exception_on_empty_results: raise ValueError(f"no matching row found for query: {query}") @@ -229,17 +228,12 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs): responses.append((req, beam.Row())) return responses else: - request_dict = request._asdict() if self.query_fn: # if a query_fn is provided then it return a list of values # that should be populated into the query template string. query = self.query_fn(request) else: - values = ( - self.condition_value_fn(request) if self.condition_value_fn else - list(map(request_dict.get, self.fields))) - # construct the query. - query = self.query_template.format(*values) + query = self.query_template response_dict = self._execute_query(query) if response_dict is None: if self.throw_exception_on_empty_results: diff --git a/sdks/python/apache_beam/yaml/extended_tests/data/enrichment.yaml b/sdks/python/apache_beam/yaml/extended_tests/data/enrichment.yaml index f134133aa049..00e778b0ccae 100644 --- a/sdks/python/apache_beam/yaml/extended_tests/data/enrichment.yaml +++ b/sdks/python/apache_beam/yaml/extended_tests/data/enrichment.yaml @@ -44,44 +44,46 @@ pipelines: project: "apache-beam-testing" temp_location: "{TEMP_DIR}" - # - pipeline: - # type: chain - # transforms: - # - type: Create - # name: Data - # config: - # elements: - # - {label: '11a', name: 'S1'} - # - {label: '37a', name: 'S2'} - # - {label: '389a', name: 'S3'} - # - type: Enrichment - # name: Enriched - # config: - # enrichment_handler: 'BigQuery' - # handler_config: - # project: apache-beam-testing - # table_name: "{BQ_TABLE}" - # fields: ['label'] - # row_restriction_template: "label = '37a'" - # timeout: 30 - - # - type: MapToFields - # config: - # language: python - # fields: - # label: - # callable: 'lambda x: x.label' - # output_type: string - # rank: - # callable: 'lambda x: x.rank' - # output_type: integer - # name: - # callable: 'lambda x: x.name' - # output_type: string + - pipeline: + type: chain + transforms: + - type: Create + name: Data + config: + elements: + - {label: '11a', name: 'S1'} + - {label: '37a', name: 'S2'} + - {label: '389a', name: 'S3'} + - type: Enrichment + name: Enriched + config: + enrichment_handler: 'BigQuery' + handler_config: + project: apache-beam-testing + table_name: "{BQ_TABLE}" + fields: ['label'] + row_restriction_template: "label = '37a'" + timeout: 30 + + - type: MapToFields + config: + language: python + fields: + label: + callable: 'lambda x: x.label' + output_type: string + rank: + callable: 'lambda x: x.rank' + output_type: integer + name: + callable: 'lambda x: x.name' + output_type: string - # - type: AssertEqual - # config: - # elements: - # - {label: '37a', rank: 1, name: 'S2'} - # options: - # yaml_experimental_features: [ 'Enrichment' ] \ No newline at end of file + - type: AssertEqual + config: + elements: + - {label: '11a', rank: 0, name: 'S1'} + - {label: '37a', rank: 1, name: 'S2'} + - {label: '389a', rank: 2, name: 'S3'} + options: + yaml_experimental_features: [ 'Enrichment' ]