diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index 2df66aadcc64..c63478dc0cfc 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -22,6 +22,8 @@
 
 import collections
 import contextlib
+import hashlib
+import hmac
 import logging
 import random
 import re
@@ -32,10 +34,14 @@
 from collections.abc import Iterable
 from typing import TYPE_CHECKING
 from typing import Any
+from typing import List
 from typing import Optional
+from typing import Tuple
 from typing import TypeVar
 from typing import Union
 
+from cryptography.fernet import Fernet
+
 import apache_beam as beam
 from apache_beam import coders
 from apache_beam import pvalue
@@ -88,6 +94,8 @@
     'BatchElements',
     'CoGroupByKey',
     'Distinct',
+    'GcpSecret',
+    'GroupByEncryptedKey',
     'Keys',
     'KvSwap',
     'LogElements',
@@ -95,6 +103,7 @@
     'Reify',
     'RemoveDuplicates',
     'Reshuffle',
+    'Secret',
     'ToString',
     'Tee',
     'Values',
@@ -317,6 +326,226 @@ def RemoveDuplicates(pcoll):
   return pcoll | 'RemoveDuplicates' >> Distinct()
 
 
+class Secret():
+  """A secret management class used for handling sensitive data.
+
+  This class provides a generic interface for secret management. Implementations
+  of this class should handle fetching secrets from a secret management system.
+  """
+  def get_secret_bytes(self) -> bytes:
+    """Returns the secret as a byte string."""
+    raise NotImplementedError()
+
+  @staticmethod
+  def generate_secret_bytes() -> bytes:
+    """Generates a new secret key."""
+    return Fernet.generate_key()
+
+
+class GcpSecret(Secret):
+  """A secret manager implementation that retrieves secrets from Google Cloud
+  Secret Manager.
+  """
+  def __init__(self, version_name: str):
+    """Initializes a GcpSecret object.
+
+    Args:
+      version_name: The full version name of the secret in Google Cloud Secret
+        Manager. For example:
+        projects/<project_id>/secrets/<secret_id>/versions/1.
+        For more info, see
+        https://cloud.google.com/python/docs/reference/secretmanager/latest/google.cloud.secretmanager_v1beta1.services.secret_manager_service.SecretManagerServiceClient#google_cloud_secretmanager_v1beta1_services_secret_manager_service_SecretManagerServiceClient_access_secret_version
+    """
+    self._version_name = version_name
+
+  def get_secret_bytes(self) -> bytes:
+    """Fetches the secret payload from Google Cloud Secret Manager.
+
+    Returns:
+      The raw payload bytes of the configured secret version.
+
+    Raises:
+      RuntimeError: If the client library cannot be imported or the secret
+        version cannot be accessed.
+    """
+    try:
+      # Imported here rather than at module level: setup.py only lists
+      # google-cloud-secret-manager among the optional GCP packages.
+      from google.cloud import secretmanager
+      client = secretmanager.SecretManagerServiceClient()
+      response = client.access_secret_version(
+          request={"name": self._version_name})
+      return response.payload.data
+    except Exception as e:
+      raise RuntimeError(
+          f'Failed to retrieve secret bytes with exception {e}') from e
+
+
+class _EncryptMessage(DoFn):
+  """A DoFn that encrypts the key and value of each element."""
+  def __init__(
+      self,
+      hmac_key_secret: Secret,
+      key_coder: coders.Coder,
+      value_coder: coders.Coder):
+    self.hmac_key_secret = hmac_key_secret
+    self.key_coder = key_coder
+    self.value_coder = value_coder
+
+  def setup(self):
+    # The same secret bytes key both the HMAC used for grouping and the Fernet
+    # cipher used for encrypting the payload.
+    self._hmac_key = self.hmac_key_secret.get_secret_bytes()
+    self.fernet = Fernet(self._hmac_key)
+
+  def process(self,
+              element: Any) -> Iterable[Tuple[bytes, Tuple[bytes, bytes]]]:
+    """Encrypts the key and value of an element.
+
+    Args:
+      element: A tuple containing the key and value to be encrypted.
+
+    Yields:
+      A tuple containing the HMAC of the encoded key, and a tuple of the
+      encrypted key and value.
+    """
+    k, v = element
+    encoded_key = self.key_coder.encode(k)
+    encoded_value = self.value_coder.encode(v)
+    # The HMAC is a pure function of the key bytes, so equal keys produce the
+    # same grouping key for the downstream GroupByKey.
+    hmac_encoded_key = hmac.new(self._hmac_key, encoded_key,
+                                hashlib.sha256).digest()
+    out_element = (
+        hmac_encoded_key,
+        (self.fernet.encrypt(encoded_key), self.fernet.encrypt(encoded_value)))
+    yield out_element
+
+
+class _DecryptMessage(DoFn):
+  """A DoFn that decrypts the key and value of each element."""
+  def __init__(
+      self,
+      hmac_key_secret: Secret,
+      key_coder: coders.Coder,
+      value_coder: coders.Coder):
+    self.hmac_key_secret = hmac_key_secret
+    self.key_coder = key_coder
+    self.value_coder = value_coder
+
+  def setup(self):
+    hmac_key = self.hmac_key_secret.get_secret_bytes()
+    self.fernet = Fernet(hmac_key)
+
+  def decode_value(self, encoded_element: Tuple[bytes, bytes]) -> Any:
+    """Decrypts and decodes the value of an encrypted (key, value) pair."""
+    encrypted_value = encoded_element[1]
+    encoded_value = self.fernet.decrypt(encrypted_value)
+    real_val = self.value_coder.decode(encoded_value)
+    return real_val
+
+  def filter_elements_by_key(
+      self,
+      encrypted_key: bytes,
+      encoded_elements: Iterable[Tuple[bytes, bytes]]) -> Iterable[Any]:
+    """Yields the decoded values whose decrypted key equals `encrypted_key`.
+
+    Despite its name, `encrypted_key` is compared against decrypted key bytes,
+    so callers must pass the decrypted (coder-encoded) key.
+    """
+    for e in encoded_elements:
+      if encrypted_key == self.fernet.decrypt(e[0]):
+        yield self.decode_value(e)
+
+  # Right now, GBK always returns a list of elements, so we match this behavior
+  # here. This does mean that the whole list will be materialized every time,
+  # but passing an Iterable containing an Iterable breaks when pickling happens
+  def process(
+      self, element: Tuple[bytes, Iterable[Tuple[bytes, bytes]]]
+  ) -> Iterable[Tuple[Any, List[Any]]]:
+    """Decrypts the key and values of an element.
+
+    Args:
+      element: A tuple containing the HMAC of the encoded key and an iterable
+        of tuples of encrypted keys and values.
+
+    Yields:
+      A tuple containing the decrypted key and a list of decrypted values.
+    """
+    unused_hmac_encoded_key, encoded_elements = element
+
+    # Distinct keys could share an HMAC, so the elements are regrouped by their
+    # decrypted key. A single pass decrypts each key only once; dicts preserve
+    # insertion order, so keys are emitted in first-seen order.
+    values_by_encoded_key = {}
+    for e in encoded_elements:
+      encoded_key = self.fernet.decrypt(e[0])
+      values_by_encoded_key.setdefault(encoded_key, []).append(
+          self.decode_value(e))
+
+    for encoded_key, values in values_by_encoded_key.items():
+      yield self.key_coder.decode(encoded_key), values
+
+
+@typehints.with_input_types(Tuple[K, V])
+@typehints.with_output_types(Tuple[K, Iterable[V]])
+class GroupByEncryptedKey(PTransform):
+  """A PTransform that provides a secure alternative to GroupByKey.
+
+  This transform encrypts the keys of the input PCollection, performs a
+  GroupByKey on the encrypted keys, and then decrypts the keys in the output.
+  This is useful when the keys contain sensitive data that should not be
+  stored at rest by the runner. Note the following caveats:
+
+  1) Runners can implement arbitrary materialization steps, so this does not
+  guarantee that the whole pipeline will not have unencrypted data at rest by
+  itself.
+  2) If using this transform in streaming mode, this transform may not properly
+  handle update compatibility checks around coders. This means that an improper
+  update could lead to invalid coders, causing pipeline failure or data
+  corruption. If you need to update, make sure that the input type passed into
+  this transform does not change.
+  """
+  def __init__(self, hmac_key: Secret):
+    """Initializes a GroupByEncryptedKey transform.
+
+    Args:
+      hmac_key: A Secret object that provides the secret key for HMAC and
+        encryption. For example, a GcpSecret can be used to access a secret
+        stored in GCP Secret Manager.
+    """
+    self._hmac_key = hmac_key
+
+  def expand(self, pcoll):
+    """Encrypts each element, groups by the key HMAC, then decrypts groups."""
+    kv_type_hint = pcoll.element_type
+    if kv_type_hint and kv_type_hint != typehints.Any:
+      coder = coders.registry.get_coder(kv_type_hint).as_deterministic_coder(
+          f'GroupByEncryptedKey {self.label}',
+          'The key coder is not deterministic. This may result in incorrect '
+          'pipeline output. This can be fixed by adding a type hint to the '
+          'operation preceding the GroupByKey step, and for custom key '
+          'classes, by writing a deterministic custom Coder. Please see the '
+          'documentation for more details.')
+      if not coder.is_kv_coder():
+        raise ValueError(
+            'Input elements to the transform %s must be key-value '
+            'pairs.' % self)
+      key_coder = coder.key_coder()
+      value_coder = coder.value_coder()
+    else:
+      # Without a usable element type hint, the coder registered for
+      # typehints.Any is used for both keys and values.
+      key_coder = coders.registry.get_coder(typehints.Any)
+      value_coder = key_coder
+
+    return (
+        pcoll
+        | beam.ParDo(_EncryptMessage(self._hmac_key, key_coder, value_coder))
+        | beam.GroupByKey()
+        | beam.ParDo(_DecryptMessage(self._hmac_key, key_coder, value_coder)))
+
+
 class _BatchSizeEstimator(object):
   """Estimates the best size for batches given historical timing.
   """
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
index 66e7a9e194d3..6cd8d5fcba76 100644
--- a/sdks/python/apache_beam/transforms/util_test.py
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -21,19 +21,25 @@
 # pylint: disable=too-many-function-args
 
 import collections
+import hashlib
+import hmac
 import importlib
 import logging
 import math
 import random
 import re
+import string
 import time
 import unittest
 import warnings
 from collections.abc import Mapping
 from datetime import datetime
 
+import mock
 import pytest
 import pytz
+from cryptography.fernet import Fernet
+from cryptography.fernet import InvalidToken
 from parameterized import param
 from parameterized import parameterized
 
@@ -65,6 +71,8 @@
 from apache_beam.transforms.core import FlatMapTuple
 from apache_beam.transforms.trigger import AfterCount
 from apache_beam.transforms.trigger import Repeatedly
+from apache_beam.transforms.util import GcpSecret
+from apache_beam.transforms.util import Secret
 from apache_beam.transforms.window import FixedWindows
 from apache_beam.transforms.window import GlobalWindow
 from apache_beam.transforms.window import GlobalWindows
@@ -88,6 +96,11 @@
 except ImportError:
   dill = None
 
+try:
+  from google.cloud import secretmanager
+except ImportError:
+  secretmanager = None  # type: ignore[assignment]
+
 warnings.filterwarnings(
     'ignore', category=FutureWarning, module='apache_beam.transform.util_test')
 
@@ -238,6 +251,134 @@ def test_co_group_by_key_on_unpickled(self):
       assert_that(pcoll, equal_to(expected))
 
 
+class FakeSecret(beam.Secret):
+  def __init__(self, should_throw=False):
+    self._secret = b'aKwI2PmqYFt2p5tNKCyBS5qYmHhHsGZcyZrnZQiQ-uE='
+    self._should_throw = should_throw
+
+  def get_secret_bytes(self) -> bytes:
+    if self._should_throw:
+      raise RuntimeError('Exception retrieving secret')
+    return self._secret
+
+
+class MockNoOpDecrypt(beam.transforms.util._DecryptMessage):
+  def __init__(self, hmac_key_secret, key_coder, value_coder):
+    hmac_key = hmac_key_secret.get_secret_bytes()
+    self.fernet_tester = Fernet(hmac_key)
+    self.known_hmacs = []
+    for key in ['a', 'b', 'c']:
+      self.known_hmacs.append(
+          hmac.new(hmac_key, key_coder.encode(key), hashlib.sha256).digest())
+    super().__init__(hmac_key_secret, key_coder, value_coder)
+
+  def process(self, element):
+    hmac_key, actual_elements = element
+    if hmac_key not in self.known_hmacs:
+      raise ValueError(f'GBK produced unencrypted value {hmac_key}')
+    for e in actual_elements:
+      try:
+        self.fernet_tester.decrypt(e[0], None)
+      except InvalidToken:
+        raise ValueError(f'GBK produced unencrypted value {e[0]}')
+      try:
+        self.fernet_tester.decrypt(e[1], None)
+      except InvalidToken:
+        raise ValueError(f'GBK produced unencrypted value {e[1]}')
+
+    return super().process(element)
+
+
+class GroupByEncryptedKeyTest(unittest.TestCase):
+  def setUp(self):
+    if secretmanager is not None:
+      self.project_id = 'apache-beam-testing'
+      secret_postfix = ''.join(random.choice(string.digits) for _ in range(6))
+      self.secret_id = 'gbek_secret_tests_' + secret_postfix
+      self.client = secretmanager.SecretManagerServiceClient()
+      self.project_path = f'projects/{self.project_id}'
+      self.secret_path = f'{self.project_path}/secrets/{self.secret_id}'
+      try:
+        self.client.get_secret(request={'name': self.secret_path})
+      except Exception:
+        self.client.create_secret(
+            request={
+                'parent': self.project_path,
+                'secret_id': self.secret_id,
+                'secret': {
+                    'replication': {
+                        'automatic': {}
+                    }
+                }
+            })
+        self.client.add_secret_version(
+            request={
+                'parent': self.secret_path,
+                'payload': {
+                    'data': Secret.generate_secret_bytes()
+                }
+            })
+      self.gcp_secret = GcpSecret(f'{self.secret_path}/versions/latest')
+
+  def tearDown(self):
+    if secretmanager is not None:
+      self.client.delete_secret(request={'name': self.secret_path})
+
+  def test_gbek_fake_secret_manager_roundtrips(self):
+    fakeSecret = FakeSecret()
+
+    with TestPipeline() as pipeline:
+      pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2),
+                                                    ('b', 3), ('c', 4)])
+      result = (pcoll_1) | beam.GroupByEncryptedKey(fakeSecret)
+      assert_that(
+          result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))]))
+
+  @mock.patch('apache_beam.transforms.util._DecryptMessage', MockNoOpDecrypt)
+  def test_gbek_fake_secret_manager_actually_does_encryption(self):
+    fakeSecret = FakeSecret()
+
+    with TestPipeline('FnApiRunner') as pipeline:
+      pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2),
+                                                    ('b', 3), ('c', 4)])
+      result = (pcoll_1) | beam.GroupByEncryptedKey(fakeSecret)
+      assert_that(
+          result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))]))
+
+  def test_gbek_fake_secret_manager_throws(self):
+    fakeSecret = FakeSecret(True)
+
+    with self.assertRaisesRegex(RuntimeError, r'Exception retrieving secret'):
+      with TestPipeline() as pipeline:
+        pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2),
+                                                      ('b', 3), ('c', 4)])
+        result = (pcoll_1) | beam.GroupByEncryptedKey(fakeSecret)
+        assert_that(
+            result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))]))
+
+  @unittest.skipIf(secretmanager is None, 'GCP dependencies are not installed')
+  def test_gbek_gcp_secret_manager_roundtrips(self):
+    with TestPipeline() as pipeline:
+      pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2),
+                                                    ('b', 3), ('c', 4)])
+      result = (pcoll_1) | beam.GroupByEncryptedKey(self.gcp_secret)
+      assert_that(
+          result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))]))
+
+  @unittest.skipIf(secretmanager is None, 'GCP dependencies are not installed')
+  def test_gbek_gcp_secret_manager_throws(self):
+    gcp_secret = GcpSecret('bad_path/versions/latest')
+
+    with self.assertRaisesRegex(RuntimeError,
+                                r'Failed to retrieve secret bytes'):
+      with TestPipeline() as pipeline:
+        pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2),
+                                                      ('b', 3), ('c', 4)])
+        result = (pcoll_1) | beam.GroupByEncryptedKey(gcp_secret)
+        assert_that(
+            result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))]))
+
+
 class FakeClock(object):
   def __init__(self, now=time.time()):
     self._now = now
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index 102eb3ac2d17..c23d69225d52 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -359,6 +359,7 @@ def get_portability_package_data():
       ext_modules=extensions,
       install_requires=[
           'crcmod>=1.7,<2.0',
+          'cryptography>=39.0.0,<48.0.0',
           'orjson>=3.9.7,<4',
           'fastavro>=0.23.6,<2',
           'fasteners>=0.3,<1.0',
@@ -476,6 +477,7 @@ def get_portability_package_data():
               # GCP Packages required by ML functionality
               'google-cloud-dlp>=3.0.0,<4',
               'google-cloud-language>=2.0,<3',
+              'google-cloud-secret-manager>=2.0,<3',
               'google-cloud-videointelligence>=2.0,<3',
               'google-cloud-vision>=2,<4',
               'google-cloud-recommendations-ai>=0.1.0,<0.11.0',