Commit b4d061e

[BEAM-34076] Added TTL caching for BigQuery table definitions
1 parent c111c8b

2 files changed: +109 -18 lines
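For context, here is a minimal sketch (not part of this commit) of the cachetools TTLCache semantics the change builds on: each entry expires `ttl` seconds after insertion, so repeated lookups of the same table within that window are served from memory instead of hitting the BigQuery backend.

  from cachetools import TTLCache

  # Same bounds the commit uses for _TABLE_CACHE: up to 1024 entries,
  # each evicted 300 seconds after it was stored.
  cache = TTLCache(maxsize=1024, ttl=300)
  cache[('my-project', 'my_dataset', 'my_table')] = 'table-metadata'

  # Present until the TTL elapses; afterwards the key is gone and the
  # caller must fetch a fresh copy from the service.
  print(('my-project', 'my_dataset', 'my_table') in cache)  # True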

sdks/python/apache_beam/io/gcp/bigquery_tools.py (35 additions, 18 deletions)
@@ -37,6 +37,7 @@
 import sys
 import time
 import uuid
+import threading
 from json.decoder import JSONDecodeError
 from typing import Optional
 from typing import Sequence
@@ -66,6 +67,8 @@
 from apache_beam.typehints.typehints import Any
 from apache_beam.utils import retry
 from apache_beam.utils.histogram import LinearBucket
+from cachetools import TTLCache, cachedmethod, Cache
+from cachetools.keys import hashkey

 # Protect against environments where bigquery library is not available.
 try:
@@ -139,6 +142,12 @@ class ExportCompression(object):
   SNAPPY = 'SNAPPY'
   NONE = 'NONE'

+class _NonNoneTTLCache(TTLCache):
+  """TTLCache that does not store None values."""
+  def __setitem__(self, key, value, cache_setitem=Cache.__setitem__):
+    if value is not None:
+      super().__setitem__(key=key, value=value)
+

 def default_encoder(obj):
   if isinstance(obj, decimal.Decimal):
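A standalone sketch of the subclass's behavior (illustrative, mirroring the class added above): assigning None is silently dropped, so a lookup that produced no table is retried on the next call instead of pinning None in the cache for the full TTL.

  from cachetools import Cache, TTLCache

  class _NonNoneTTLCache(TTLCache):  # as defined in the diff above
    """TTLCache that does not store None values."""
    def __setitem__(self, key, value, cache_setitem=Cache.__setitem__):
      if value is not None:
        super().__setitem__(key=key, value=value)

  cache = _NonNoneTTLCache(maxsize=2, ttl=300)
  cache['a'] = None       # silently dropped, never cached
  cache['b'] = 'table'    # stored with the normal 300 s TTL
  assert 'a' not in cache
  assert cache['b'] == 'table'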
@@ -359,6 +368,9 @@ class BigQueryWrapper(object):

   HISTOGRAM_METRIC_LOGGER = MetricLogger()

+  _TABLE_CACHE = _NonNoneTTLCache(maxsize=1024, ttl=300)
+  _TABLE_CACHE_LOCK = threading.RLock()
+
   def __init__(self, client=None, temp_dataset_id=None, temp_table_ref=None):
     self.client = client or BigQueryWrapper._bigquery_client(PipelineOptions())
     self.gcp_bq_client = client or gcp_bigquery.Client(
@@ -788,27 +800,32 @@ def _insert_all_rows(
         int(time.time() * 1000) - started_millis)
     return not errors, errors

+  @cachedmethod(
+      cache=lambda self: self._TABLE_CACHE,
+      lock=lambda self: self._TABLE_CACHE_LOCK,
+      key=lambda self, project_id, dataset_id, table_id: hashkey(
+          project_id, dataset_id, table_id),
+  )
   @retry.with_exponential_backoff(
-      num_retries=MAX_RETRIES,
-      retry_filter=retry.retry_on_server_errors_timeout_or_quota_issues_filter)
+      num_retries=MAX_RETRIES,
+      retry_filter=retry.retry_on_server_errors_timeout_or_quota_issues_filter)
   def get_table(self, project_id, dataset_id, table_id):
-    """Lookup a table's metadata object.
-
-    Args:
-      client: bigquery.BigqueryV2 instance
-      project_id: table lookup parameter
-      dataset_id: table lookup parameter
-      table_id: table lookup parameter
-
-    Returns:
-      bigquery.Table instance
-    Raises:
-      HttpError: if lookup failed.
-    """
+    """Lookup a table's metadata object (TTL cached at class level).
+
+    Args:
+      client: bigquery.BigqueryV2 instance
+      project_id: table lookup parameter
+      dataset_id: table lookup parameter
+      table_id: table lookup parameter
+
+    Returns:
+      bigquery.Table instance
+    Raises:
+      HttpError: if lookup failed.
+    """
     request = bigquery.BigqueryTablesGetRequest(
-        projectId=project_id, datasetId=dataset_id, tableId=table_id)
-    response = self.client.tables.Get(request)
-    return response
+        projectId=project_id, datasetId=dataset_id, tableId=table_id)
+    return self.client.tables.Get(request)

   def _create_table(
       self,
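A minimal self-contained sketch of the decorator wiring above (names are illustrative, not Beam APIs; assumes cachetools 5.x, where the key callable receives self): @cachedmethod resolves the cache and lock through lambdas at call time, and since both are class attributes, every wrapper instance shares one cache keyed only by (project_id, dataset_id, table_id). Note the ordering: with @cachedmethod outermost, a cache hit returns before the retry decorator is ever entered.

  import threading
  from cachetools import TTLCache, cachedmethod
  from cachetools.keys import hashkey

  class Wrapper:
    _CACHE = TTLCache(maxsize=1024, ttl=300)   # shared by all instances
    _LOCK = threading.RLock()                  # guards concurrent access

    @cachedmethod(
        cache=lambda self: self._CACHE,
        lock=lambda self: self._LOCK,
        key=lambda self, project_id, dataset_id, table_id: hashkey(
            project_id, dataset_id, table_id),
    )
    def get_table(self, project_id, dataset_id, table_id):
      print('backend call')   # runs at most once per key per TTL window
      return f'{project_id}.{dataset_id}.{table_id}'

  w1, w2 = Wrapper(), Wrapper()
  w1.get_table('p', 'd', 't')   # prints 'backend call'
  w2.get_table('p', 'd', 't')   # cache hit: same result, no print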

sdks/python/apache_beam/io/gcp/bigquery_tools_test.py (74 additions, 0 deletions)
@@ -292,6 +292,80 @@ def test_temporary_dataset_is_unique(self, patched_time_sleep):
     wrapper.create_temporary_dataset('project-id', 'location')
     self.assertTrue(client.datasets.Get.called)

+  def test_get_table_invokes_tables_get_and_caches_result(self):
+
+    from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper
+
+    client = mock.Mock()
+    client.tables = mock.Mock()
+
+    returned_table = mock.Mock(name="BigQueryTable")
+    client.tables.Get = mock.Mock(return_value=returned_table)
+
+    wrapper = BigQueryWrapper(client=client)
+
+    project_id = "my-project"
+    dataset_id = "my_dataset"
+    table_id = "my_table"
+
+    table1 = wrapper.get_table(project_id, dataset_id, table_id)
+
+    assert table1 is returned_table
+    assert client.tables.Get.call_count == 1
+
+    (request,), _ = client.tables.Get.call_args
+    assert isinstance(request, bigquery.BigqueryTablesGetRequest)
+    assert request.projectId == project_id
+    assert request.datasetId == dataset_id
+    assert request.tableId == table_id
+
+    table2 = wrapper.get_table(project_id, dataset_id, table_id)
+
+    assert table2 is returned_table
+    assert client.tables.Get.call_count == 1  # still 1 => cached
+
+  def test_get_table_shared_cache_across_wrapper_instances(self):
+    from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper
+
+    # ensure isolation -> clear the shared cache before the test
+    BigQueryWrapper._TABLE_CACHE.clear()
+
+    client = mock.Mock()
+    client.tables = mock.Mock()
+
+    returned_table = mock.Mock(name="BigQueryTable")
+    client.tables.Get = mock.Mock(return_value=returned_table)
+
+    project_id = "my-project"
+    dataset_id = "my_dataset"
+    table_id = "my_table"
+
+    w1 = BigQueryWrapper(client=client)
+    w2 = BigQueryWrapper(client=client)
+    w3 = BigQueryWrapper(client=client)
+
+    # first call -> populate cache
+    t1 = w1.get_table(project_id, dataset_id, table_id)
+    assert t1 is returned_table
+    assert client.tables.Get.call_count == 1
+
+    # verify request shape (from first call)
+    (request,), _ = client.tables.Get.call_args
+    assert isinstance(request, bigquery.BigqueryTablesGetRequest)
+    assert request.projectId == project_id
+    assert request.datasetId == dataset_id
+    assert request.tableId == table_id
+
+    # calls from DIFFERENT wrapper instances -> hit the SAME cache entry
+    t2 = w2.get_table(project_id, dataset_id, table_id)
+    t3 = w3.get_table(project_id, dataset_id, table_id)
+
+    assert t2 is returned_table
+    assert t3 is returned_table
+
+    # still 1 -> record cached across instances
+    assert client.tables.Get.call_count == 1
+
   def test_get_or_create_dataset_created(self):
     client = mock.Mock()
     client.datasets.Get.side_effect = HttpError(
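One caveat worth noting about the tests: _TABLE_CACHE is class-level state, so entries leak between tests. The second test clears it inline; the first relies on its key not having been cached yet. A hypothetical setUp-based fixture (not part of this commit) that would isolate every test:

  import unittest
  from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper

  class CacheIsolatedTestCase(unittest.TestCase):
    def setUp(self):
      # Start each test with an empty shared table cache so no entry
      # written by an earlier test can satisfy this test's lookups.
      BigQueryWrapper._TABLE_CACHE.clear()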
