Skip to content

Commit 6c0ed7e

Browse files
committed
Fix BigQueryEnrichmentHandler
1 parent 3d5d04b commit 6c0ed7e

File tree

2 files changed

+45
-28
lines changed

2 files changed

+45
-28
lines changed

sdks/python/apache_beam/transforms/enrichment_handlers/bigquery.py

Lines changed: 40 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -141,11 +141,12 @@ def __init__(
141141
self.project = project
142142
self.column_names = column_names
143143
self.select_fields = ",".join(column_names) if column_names else '*'
144-
self.row_restriction_template = row_restriction_template
144+
self.row_restriction_template = row_restriction_template.replace('{{}}', '{}')
145145
self.table_name = table_name
146146
self.fields = fields if fields else []
147147
self.condition_value_fn = condition_value_fn
148148
self.query_fn = query_fn
149+
self._has_placeholders = '{}' in self.row_restriction_template
149150
self.query_template = (
150151
"SELECT %s FROM %s WHERE %s" %
151152
(self.select_fields, self.table_name, self.row_restriction_template))
@@ -197,29 +198,41 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
197198
[fr'({self.row_restriction_template})'] * batch_size)
198199
raw_query = self.query_template.replace(
199200
self.row_restriction_template, batched_condition_template)
200-
for req in request:
201-
request_dict = req._asdict()
202-
try:
203-
current_values = (
204-
self.condition_value_fn(req) if self.condition_value_fn else
205-
[request_dict[field] for field in self.fields])
206-
except KeyError as e:
207-
raise KeyError(
208-
"Make sure the values passed in `fields` are the "
209-
"keys in the input `beam.Row`." + str(e))
210-
values.extend(current_values)
211-
requests_map[self.create_row_key(req)] = req
212-
query = raw_query.format(*values)
201+
if self._has_placeholders:
202+
for req in request:
203+
request_dict = req._asdict()
204+
try:
205+
current_values = (
206+
self.condition_value_fn(req) if self.condition_value_fn else
207+
[request_dict[field] for field in self.fields])
208+
except KeyError as e:
209+
raise KeyError(
210+
"Make sure the values passed in `fields` are the "
211+
"keys in the input `beam.Row`." + str(e))
212+
values.extend(current_values)
213+
requests_map[self.create_row_key(req)] = req
214+
query = raw_query.format(*values)
215+
else:
216+
for req in request:
217+
requests_map[id(req)] = req # Use object id as key
218+
query = raw_query
213219

214220
responses_dict = self._execute_query(query)
215221
unmatched_requests = requests_map.copy()
216222
if responses_dict:
217-
for response in responses_dict:
218-
response_row = beam.Row(**response)
219-
response_key = self.create_row_key(response_row)
220-
if response_key in unmatched_requests:
221-
req = unmatched_requests.pop(response_key)
222-
responses.append((req, response_row))
223+
if self._has_placeholders:
224+
for response in responses_dict:
225+
response_row = beam.Row(**response)
226+
response_key = self.create_row_key(response_row)
227+
if response_key in unmatched_requests:
228+
req = unmatched_requests.pop(response_key)
229+
responses.append((req, response_row))
230+
else:
231+
if responses_dict:
232+
response_row = beam.Row(**responses_dict[0])
233+
for req in unmatched_requests.values():
234+
responses.append((req, response_row))
235+
unmatched_requests.clear()
223236
if unmatched_requests:
224237
if self.throw_exception_on_empty_results:
225238
raise ValueError(f"no matching row found for query: {query}")
@@ -235,11 +248,13 @@ def __call__(self, request: Union[beam.Row, list[beam.Row]], *args, **kwargs):
235248
# that should be populated into the query template string.
236249
query = self.query_fn(request)
237250
else:
238-
values = (
239-
self.condition_value_fn(request) if self.condition_value_fn else
240-
list(map(request_dict.get, self.fields)))
241-
# construct the query.
242-
query = self.query_template.format(*values)
251+
if self._has_placeholders:
252+
values = (
253+
self.condition_value_fn(request) if self.condition_value_fn else
254+
list(map(request_dict.get, self.fields)))
255+
query = self.query_template.format(*values)
256+
else:
257+
query = self.query_template
243258
response_dict = self._execute_query(query)
244259
if response_dict is None:
245260
if self.throw_exception_on_empty_results:

sdks/python/apache_beam/yaml/extended_tests/data/enrichment.yaml

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -62,9 +62,9 @@ pipelines:
6262
project: apache-beam-testing
6363
table_name: "{BQ_TABLE}"
6464
fields: ['label']
65-
row_restriction_template: "label = '37a'"
65+
row_restriction_template: "label = '{{}}'"
6666
timeout: 30
67-
67+
6868
- type: MapToFields
6969
config:
7070
language: python
@@ -82,6 +82,8 @@ pipelines:
8282
- type: AssertEqual
8383
config:
8484
elements:
85+
- {label: '11a', rank: 0, name: 'S1'}
8586
- {label: '37a', rank: 1, name: 'S2'}
87+
- {label: '389a', rank: 2, name: 'S3'}
8688
options:
87-
yaml_experimental_features: [ 'Enrichment' ]
89+
yaml_experimental_features: [ 'Enrichment' ]

0 commit comments

Comments
 (0)