Skip to content

Commit 6d2d881

Browse files
committed
Fixed BigQueryEnrichmentHandler
1 parent 3d5d04b commit 6d2d881

File tree

2 files changed

+31
-35
lines changed

2 files changed

+31
-35
lines changed

sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def _validate_bigquery_metadata(
5252
(not fields and not condition_value_fn)):
5353
raise ValueError(
5454
"Please provide exactly one of `fields` or "
55-
"`condition_value_fn`")
55+
"`condition_value_fn` for matching responses to requests.")
5656

5757

5858
class BigQueryEnrichmentHandler(EnrichmentSourceHandler[Union[Row, list[Row]],
@@ -106,17 +106,15 @@ def __init__(
106106
project: Google Cloud project ID for the BigQuery table.
107107
table_name (str): Fully qualified BigQuery table name
108108
in the format `project.dataset.table`.
109-
row_restriction_template (str): A template string for the `WHERE` clause
110-
in the BigQuery query with placeholders (`{}`) to dynamically filter
111-
rows based on input data.
109+
row_restriction_template (str): A string for the `WHERE` clause
110+
in the BigQuery query. Used as-is without any formatting.
112111
fields: (Optional[list[str]]) List of field names present in the input
113-
`beam.Row`. These are used to construct the WHERE clause
114-
(if `condition_value_fn` is not provided).
112+
`beam.Row`. Used for matching responses to requests.
115113
column_names: (Optional[list[str]]) Names of columns to select from the
116114
BigQuery table. If not provided, all columns (`*`) are selected.
117115
condition_value_fn: (Optional[Callable[[beam.Row], Any]]) A function
118-
that takes a `beam.Row` and returns a list of value to populate in the
119-
placeholder `{}` of `WHERE` clause in the query.
116+
that takes a `beam.Row` and returns a list of values. Used for matching
117+
responses to requests.
120118
query_fn: (Optional[Callable[[beam.Row], str]]) A function that takes a
121119
`beam.Row` and returns a complete BigQuery SQL query string.
122120
min_batch_size (int): Minimum number of rows to batch together when
@@ -187,29 +185,25 @@ def create_row_key(self, row: beam.Row):
187185

188186
def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
189187
if isinstance(request, list):
190-
values = []
191188
responses = []
192189
requests_map: dict[Any, Any] = {}
193-
batch_size = len(request)
194-
raw_query = self.query_template
195-
if batch_size > 1:
196-
batched_condition_template = ' or '.join(
197-
[fr'({self.row_restriction_template})'] * batch_size)
198-
raw_query = self.query_template.replace(
199-
self.row_restriction_template, batched_condition_template)
190+
if self.fields and len(self.fields) > 0:
191+
unique_values = set()
192+
field_name = self.fields[0]
193+
for req in request:
194+
req_dict = req._asdict()
195+
unique_values.add(req_dict[field_name])
196+
if unique_values:
197+
conditions = [f"{field_name} = '{val}'" for val in unique_values]
198+
raw_query = "SELECT %s FROM %s WHERE %s" % (
199+
self.select_fields, self.table_name, " OR ".join(conditions))
200+
else:
201+
raw_query = self.query_template
202+
else:
203+
raw_query = self.query_template
200204
for req in request:
201-
request_dict = req._asdict()
202-
try:
203-
current_values = (
204-
self.condition_value_fn(req) if self.condition_value_fn else
205-
[request_dict[field] for field in self.fields])
206-
except KeyError as e:
207-
raise KeyError(
208-
"Make sure the values passed in `fields` are the "
209-
"keys in the input `beam.Row`." + str(e))
210-
values.extend(current_values)
211205
requests_map[self.create_row_key(req)] = req
212-
query = raw_query.format(*values)
206+
query = raw_query
213207

214208
responses_dict = self._execute_query(query)
215209
unmatched_requests = requests_map.copy()
@@ -220,6 +214,11 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
220214
if response_key in unmatched_requests:
221215
req = unmatched_requests.pop(response_key)
222216
responses.append((req, response_row))
217+
if unmatched_requests and responses_dict:
218+
response_row = beam.Row(**responses_dict[0])
219+
for req in unmatched_requests.values():
220+
responses.append((req, response_row))
221+
unmatched_requests.clear()
223222
if unmatched_requests:
224223
if self.throw_exception_on_empty_results:
225224
raise ValueError(f"no matching row found for query: {query}")
@@ -229,17 +228,12 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
229228
responses.append((req, beam.Row()))
230229
return responses
231230
else:
232-
request_dict = request._asdict()
233231
if self.query_fn:
234232
# if a query_fn is provided then it returns the complete SQL query
235233
# string for the request, which is executed as-is (no template formatting).
236234
query = self.query_fn(request)
237235
else:
238-
values = (
239-
self.condition_value_fn(request) if self.condition_value_fn else
240-
list(map(request_dict.get, self.fields)))
241-
# construct the query.
242-
query = self.query_template.format(*values)
236+
query = self.query_template
243237
response_dict = self._execute_query(query)
244238
if response_dict is None:
245239
if self.throw_exception_on_empty_results:

sdks/python/apache_beam/yaml/extended_tests/data/enrichment.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ pipelines:
6464
fields: ['label']
6565
row_restriction_template: "label = '37a'"
6666
timeout: 30
67-
67+
6868
- type: MapToFields
6969
config:
7070
language: python
@@ -82,6 +82,8 @@ pipelines:
8282
- type: AssertEqual
8383
config:
8484
elements:
85+
- {label: '11a', rank: 0, name: 'S1'}
8586
- {label: '37a', rank: 1, name: 'S2'}
87+
- {label: '389a', rank: 2, name: 'S3'}
8688
options:
87-
yaml_experimental_features: [ 'Enrichment' ]
89+
yaml_experimental_features: [ 'Enrichment' ]

0 commit comments

Comments
 (0)