Skip to content

Commit 29b1abe

Browse files
Adding CMEK support for temp_dataset for Python Bigquery (#36118)
* Adding CMEK support for temp_dataset * Corrected formatting * Resolved conflict * formatting * Formatting * Fixing tests
1 parent 6077034 commit 29b1abe

File tree

5 files changed

+53
-8
lines changed

5 files changed

+53
-8
lines changed

sdks/python/apache_beam/io/gcp/bigquery.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -850,7 +850,8 @@ def _setup_temporary_dataset(self, bq):
850850
return
851851
location = bq.get_query_location(
852852
self._get_project(), self.query.get(), self.use_legacy_sql)
853-
bq.create_temporary_dataset(self._get_project(), location)
853+
bq.create_temporary_dataset(
854+
self._get_project(), location, kms_key=self.kms_key)
854855

855856
@check_accessible(['query'])
856857
def _execute_query(self, bq):
@@ -1062,7 +1063,10 @@ def _setup_temporary_dataset(self, bq):
10621063
self._get_parent_project(), self.query.get(), self.use_legacy_sql)
10631064
_LOGGER.warning("### Labels: %s", str(self.bigquery_dataset_labels))
10641065
bq.create_temporary_dataset(
1065-
self._get_parent_project(), location, self.bigquery_dataset_labels)
1066+
self._get_parent_project(),
1067+
location,
1068+
self.bigquery_dataset_labels,
1069+
kms_key=self.kms_key)
10661070

10671071
@check_accessible(['query'])
10681072
def _execute_query(self, bq):

sdks/python/apache_beam/io/gcp/bigquery_read_internal.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,8 @@ def _setup_temporary_dataset(
319319
# Use the project from temp_dataset if it's a DatasetReference,
320320
# otherwise use the pipeline project
321321
temp_dataset_project = self._get_temp_dataset_project()
322-
bq.create_temporary_dataset(temp_dataset_project, location)
322+
bq.create_temporary_dataset(
323+
temp_dataset_project, location, kms_key=self.kms_key)
323324

324325
def _execute_query(
325326
self,

sdks/python/apache_beam/io/gcp/bigquery_read_internal_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def test_setup_temporary_dataset_uses_correct_project(self, mock_bq_wrapper):
9999

100100
# Verify that create_temporary_dataset was called with the custom project
101101
mock_bq.create_temporary_dataset.assert_called_once_with(
102-
'custom-project', 'US')
102+
'custom-project', 'US', kms_key=None)
103103
# Verify that get_query_location was called with the pipeline project
104104
mock_bq.get_query_location.assert_called_once_with(
105105
'test-project', 'SELECT * FROM table', False)
@@ -145,7 +145,7 @@ def test_setup_temporary_dataset_with_string_temp_dataset(
145145

146146
# Verify that create_temporary_dataset was called with the pipeline project
147147
mock_bq.create_temporary_dataset.assert_called_once_with(
148-
'test-project', 'US')
148+
'test-project', 'US', kms_key=None)
149149

150150
@mock.patch('apache_beam.io.gcp.bigquery_tools.BigQueryWrapper')
151151
def test_finish_bundle_with_string_temp_dataset(self, mock_bq_wrapper):

sdks/python/apache_beam/io/gcp/bigquery_tools.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,10 @@ def _build_filter_from_labels(labels):
333333
return filter_str
334334

335335

336+
def _build_dataset_encryption_config(kms_key):
337+
return bigquery.EncryptionConfiguration(kmsKeyName=kms_key)
338+
339+
336340
class BigQueryWrapper(object):
337341
"""BigQuery client wrapper with utilities for querying.
338342
@@ -835,7 +839,7 @@ def _create_table(
835839
num_retries=MAX_RETRIES,
836840
retry_filter=retry.retry_on_server_errors_and_timeout_filter)
837841
def get_or_create_dataset(
838-
self, project_id, dataset_id, location=None, labels=None):
842+
self, project_id, dataset_id, location=None, labels=None, kms_key=None):
839843
# Check if dataset already exists otherwise create it
840844
try:
841845
dataset = self.client.datasets.Get(
@@ -858,6 +862,9 @@ def get_or_create_dataset(
858862
dataset.location = location
859863
if labels is not None:
860864
dataset.labels = _build_dataset_labels(labels)
865+
if kms_key is not None:
866+
dataset.defaultEncryptionConfiguration = (
867+
_build_dataset_encryption_config(kms_key))
861868
request = bigquery.BigqueryDatasetsInsertRequest(
862869
projectId=project_id, dataset=dataset)
863870
response = self.client.datasets.Insert(request)
@@ -929,9 +936,14 @@ def is_user_configured_dataset(self):
929936
@retry.with_exponential_backoff(
930937
num_retries=MAX_RETRIES,
931938
retry_filter=retry.retry_on_server_errors_and_timeout_filter)
932-
def create_temporary_dataset(self, project_id, location, labels=None):
939+
def create_temporary_dataset(
940+
self, project_id, location, labels=None, kms_key=None):
933941
self.get_or_create_dataset(
934-
project_id, self.temp_dataset_id, location=location, labels=labels)
942+
project_id,
943+
self.temp_dataset_id,
944+
location=location,
945+
labels=labels,
946+
kms_key=kms_key)
935947

936948
if (project_id is not None and not self.is_user_configured_dataset() and
937949
not self.created_temp_dataset):

sdks/python/apache_beam/io/gcp/bigquery_tools_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,34 @@ def test_get_or_create_dataset_created(self):
301301
new_dataset = wrapper.get_or_create_dataset('project-id', 'dataset_id')
302302
self.assertEqual(new_dataset.datasetReference.datasetId, 'dataset_id')
303303

304+
def test_create_temporary_dataset_with_kms_key(self):
305+
kms_key = (
306+
'projects/my-project/locations/global/keyRings/my-kr/'
307+
'cryptoKeys/my-key')
308+
client = mock.Mock()
309+
client.datasets.Get.side_effect = HttpError(
310+
response={'status': '404'}, url='', content='')
311+
312+
client.datasets.Insert.return_value = bigquery.Dataset(
313+
datasetReference=bigquery.DatasetReference(
314+
projectId='project-id', datasetId='temp_dataset'))
315+
wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
316+
317+
try:
318+
wrapper.create_temporary_dataset(
319+
'project-id', 'location', kms_key=kms_key)
320+
except Exception:
321+
pass
322+
323+
args, _ = client.datasets.Insert.call_args
324+
insert_request = args[0] # BigqueryDatasetsInsertRequest
325+
inserted_dataset = insert_request.dataset # Actual Dataset object
326+
327+
# Assertions
328+
self.assertIsNotNone(inserted_dataset.defaultEncryptionConfiguration)
329+
self.assertEqual(
330+
inserted_dataset.defaultEncryptionConfiguration.kmsKeyName, kms_key)
331+
304332
def test_get_or_create_dataset_fetched(self):
305333
client = mock.Mock()
306334
client.datasets.Get.return_value = bigquery.Dataset(

0 commit comments

Comments
 (0)