208 changes: 208 additions & 0 deletions sdks/python/apache_beam/transforms/util.py
@@ -22,6 +22,8 @@

import collections
import contextlib
import hashlib
import hmac
import logging
import random
import re
@@ -32,10 +34,14 @@
from collections.abc import Iterable
from typing import TYPE_CHECKING
from typing import Any
from typing import List
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union

from cryptography.fernet import Fernet

import apache_beam as beam
from apache_beam import coders
from apache_beam import pvalue
@@ -88,13 +94,16 @@
'BatchElements',
'CoGroupByKey',
'Distinct',
'GcpSecret',
'GroupByEncryptedKey',
'Keys',
'KvSwap',
'LogElements',
'Regex',
'Reify',
'RemoveDuplicates',
'Reshuffle',
'Secret',
'ToString',
'Tee',
'Values',
@@ -317,6 +326,205 @@ def RemoveDuplicates(pcoll):
return pcoll | 'RemoveDuplicates' >> Distinct()


class Secret():
"""A secret management class used for handling sensitive data.

This class provides a generic interface for secret management. Implementations
of this class should handle fetching secrets from a secret management system.
"""
def get_secret_bytes(self) -> bytes:
"""Returns the secret as a byte string."""
raise NotImplementedError()

@staticmethod
def generate_secret_bytes() -> bytes:
"""Generates a new secret key."""
return Fernet.generate_key()


class GcpSecret(Secret):
"""A secret manager implementation that retrieves secrets from Google Cloud
Secret Manager.
"""
def __init__(self, version_name: str):
"""Initializes a GcpSecret object.

Args:
version_name: The full version name of the secret in Google Cloud Secret
Manager. For example:
projects/<id>/secrets/<secret_name>/versions/1.
For more info, see
https://cloud.google.com/python/docs/reference/secretmanager/latest/google.cloud.secretmanager_v1beta1.services.secret_manager_service.SecretManagerServiceClient#google_cloud_secretmanager_v1beta1_services_secret_manager_service_SecretManagerServiceClient_access_secret_version
"""
self._version_name = version_name

def get_secret_bytes(self) -> bytes:
try:
from google.cloud import secretmanager
client = secretmanager.SecretManagerServiceClient()
response = client.access_secret_version(
request={"name": self._version_name})
secret = response.payload.data
return secret
except Exception as e:
      raise RuntimeError(
          f'Failed to retrieve secret bytes with exception {e}')
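
An illustrative provisioning sketch (not part of this diff): one way to generate a Fernet-compatible key with Secret.generate_secret_bytes() and store it as a new version of an existing Secret Manager secret, so that GcpSecret can retrieve it later. The project and secret names are hypothetical placeholders.

from google.cloud import secretmanager

key = Secret.generate_secret_bytes()  # url-safe base64-encoded Fernet key
client = secretmanager.SecretManagerServiceClient()
client.add_secret_version(
    request={
        'parent': 'projects/my-project/secrets/gbek-key',
        'payload': {'data': key},
    })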


class _EncryptMessage(DoFn):
"""A DoFn that encrypts the key and value of each element."""
def __init__(
self,
hmac_key_secret: Secret,
key_coder: coders.Coder,
value_coder: coders.Coder):
self.hmac_key_secret = hmac_key_secret
self.key_coder = key_coder
self.value_coder = value_coder

def setup(self):
self._hmac_key = self.hmac_key_secret.get_secret_bytes()
self.fernet = Fernet(self._hmac_key)

def process(self,
element: Any) -> Iterable[Tuple[bytes, Tuple[bytes, bytes]]]:
"""Encrypts the key and value of an element.

Args:
element: A tuple containing the key and value to be encrypted.

Yields:
A tuple containing the HMAC of the encoded key, and a tuple of the
encrypted key and value.
"""
k, v = element
encoded_key = self.key_coder.encode(k)
encoded_value = self.value_coder.encode(v)
hmac_encoded_key = hmac.new(self._hmac_key, encoded_key,
hashlib.sha256).digest()
out_element = (
hmac_encoded_key,
(self.fernet.encrypt(encoded_key), self.fernet.encrypt(encoded_value)))
yield out_element


class _DecryptMessage(DoFn):
"""A DoFn that decrypts the key and value of each element."""
def __init__(
self,
hmac_key_secret: Secret,
key_coder: coders.Coder,
value_coder: coders.Coder):
self.hmac_key_secret = hmac_key_secret
self.key_coder = key_coder
self.value_coder = value_coder

def setup(self):
hmac_key = self.hmac_key_secret.get_secret_bytes()
self.fernet = Fernet(hmac_key)

def decode_value(self, encoded_element: Tuple[bytes, bytes]) -> Any:
encrypted_value = encoded_element[1]
encoded_value = self.fernet.decrypt(encrypted_value)
real_val = self.value_coder.decode(encoded_value)
return real_val

  def filter_elements_by_key(
      self,
      encoded_key: bytes,
      encoded_elements: Iterable[Tuple[bytes, bytes]]) -> Iterable[Any]:
    for e in encoded_elements:
      if encoded_key == self.fernet.decrypt(e[0]):
        yield self.decode_value(e)

# Right now, GBK always returns a list of elements, so we match this behavior
# here. This does mean that the whole list will be materialized every time,
  # but passing an Iterable containing an Iterable breaks when pickling happens.
def process(
self, element: Tuple[bytes, Iterable[Tuple[bytes, bytes]]]
) -> Iterable[Tuple[Any, List[Any]]]:
"""Decrypts the key and values of an element.

Args:
element: A tuple containing the HMAC of the encoded key and an iterable
of tuples of encrypted keys and values.

Yields:
A tuple containing the decrypted key and a list of decrypted values.
"""
unused_hmac_encoded_key, encoded_elements = element
seen_keys = set()

    # Since there could be HMAC collisions, we use the Fernet-encrypted key to
    # confirm that the mapping is actually correct.
for e in encoded_elements:
encrypted_key, unused_encrypted_value = e
encoded_key = self.fernet.decrypt(encrypted_key)
if encoded_key in seen_keys:
continue
seen_keys.add(encoded_key)
real_key = self.key_coder.decode(encoded_key)

yield (
real_key,
list(self.filter_elements_by_key(encoded_key, encoded_elements)))


@typehints.with_input_types(Tuple[K, V])
@typehints.with_output_types(Tuple[K, Iterable[V]])
class GroupByEncryptedKey(PTransform):
"""A PTransform that provides a secure alternative to GroupByKey.

This transform encrypts the keys of the input PCollection, performs a
GroupByKey on the encrypted keys, and then decrypts the keys in the output.
This is useful when the keys contain sensitive data that should not be
stored at rest by the runner. Note the following caveats:

  1) Runners can introduce arbitrary materialization steps, so this transform
  alone does not guarantee that the pipeline has no unencrypted data at rest.
  2) When used in streaming mode, this transform may not properly handle update
  compatibility checks around coders. An improper update could therefore lead
  to invalid coders, causing pipeline failure or data corruption. If you need
  to update the pipeline, make sure that the input type passed into this
  transform does not change.
"""
def __init__(self, hmac_key: Secret):
"""Initializes a GroupByEncryptedKey transform.

Args:
hmac_key: A Secret object that provides the secret key for HMAC and
encryption. For example, a GcpSecret can be used to access a secret
        stored in GCP Secret Manager.
"""
self._hmac_key = hmac_key

def expand(self, pcoll):
kv_type_hint = pcoll.element_type
if kv_type_hint and kv_type_hint != typehints.Any:
coder = coders.registry.get_coder(kv_type_hint).as_deterministic_coder(
          f'GroupByEncryptedKey {self.label}',
'The key coder is not deterministic. This may result in incorrect '
'pipeline output. This can be fixed by adding a type hint to the '
'operation preceding the GroupByKey step, and for custom key '
'classes, by writing a deterministic custom Coder. Please see the '
'documentation for more details.')
if not coder.is_kv_coder():
raise ValueError(
            'Input elements to the GroupByEncryptedKey transform %s must be '
'key-value pairs.' % self)
key_coder = coder.key_coder()
Collaborator:

Should we assert that the key coder is deterministic here / call self.key_coder.as_deterministic_coder here? For the encrypted key to be deterministic, we assume that the actual key coder is deterministic?

Collaborator:

Wondering if it's possible to create a new coder that wraps the original coder along with doing the encryption?

Contributor Author:

I think defining this as a coder could be messy because then you would need to access the secret at construction time and include that as part of the serialized graph definition. This would then not provide sufficient security guarantees since the graph itself would have all information needed to decrypt the value.

Potentially you could include the work of downloading the secret in the coder definition, but I don't think we gain much from this today (and errors might be messy to debug). It also seems nice that we have the actual transform definition which helps make it obvious what is happening from the graph.

Contributor Author:

Yeah, this is a good call. Interestingly, we do check this for the direct runner GBK implementation, but not more broadly as far as I can tell. But we should definitely be verifying here.

Updated

Collaborator:

Got it. I think this will still circumvent update compat checks because we are pulling the original coders out of the graph?

Contributor Author:

Yeah, it will, though at least the pipeline should cleanly fail in most cases when it tries to perform the encoding. For now, I will add a note to the doc string that this should be used with caution to avoid this issue.

In the future, there are maybe ways we could address this:

  1. Instead of baking this into a transform, we could create a custom coder at the SDK level which handles this (reading in the encoded bytes and then applying the encoding on top of that in bundle_processor.py)
  2. We could create a custom type and associated coder per GBEK instance which handles all encoding pieces.

Neither of these is trivial, so for now I will leave this with the doc note.

Collaborator:

Something like: coder.encode could process a tuple (secret_version_name, unencrypted_key) and output a (secret_version_name, encrypted_key)?

Contributor Author:

I thought about this as well - it is pretty expensive though since we're serializing the secret name every time.

This is what I was trying to get around with "We could create a custom type and associated coder per GBEK instance which handles all encoding pieces," since the coder definition could contain the secret name.
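
For reference, a rough sketch of the wrapping-coder idea discussed above (hypothetical, not part of this change): a coder that wraps the original key coder, fetches the secret lazily on the worker so it is never baked into the serialized graph, and applies Fernet encryption in encode/decode. Note that Fernet tokens embed a timestamp and random IV, so such a coder would not be deterministic, which is one reason the coder-based approach is tricky for grouping keys.

from apache_beam import coders
from cryptography.fernet import Fernet

class EncryptingKeyCoder(coders.Coder):
  """Hypothetical sketch only; the name and design are illustrative."""
  def __init__(self, inner_coder, secret):
    self._inner = inner_coder
    self._secret = secret
    self._fernet = None

  def _get_fernet(self):
    # Fetch the secret lazily, at first use on the worker, rather than at
    # construction time.
    if self._fernet is None:
      self._fernet = Fernet(self._secret.get_secret_bytes())
    return self._fernet

  def encode(self, value):
    return self._get_fernet().encrypt(self._inner.encode(value))

  def decode(self, encoded):
    return self._inner.decode(self._get_fernet().decrypt(encoded))

  def is_deterministic(self):
    # Encrypting the same value twice yields different bytes.
    return False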

value_coder = coder.value_coder()
else:
key_coder = coders.registry.get_coder(typehints.Any)
value_coder = key_coder

return (
pcoll
| beam.ParDo(_EncryptMessage(self._hmac_key, key_coder, value_coder))
| beam.GroupByKey()
| beam.ParDo(_DecryptMessage(self._hmac_key, key_coder, value_coder)))
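
A minimal usage sketch for the transform above (illustrative only; assumes a Fernet-compatible key is already stored in Secret Manager under the hypothetical version name shown):

import apache_beam as beam
from apache_beam.transforms.util import GcpSecret, GroupByEncryptedKey

secret = GcpSecret('projects/my-project/secrets/gbek-key/versions/1')

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([('alice', 1), ('alice', 2), ('bob', 3)])
      | GroupByEncryptedKey(secret)
      | beam.Map(print))  # e.g. ('alice', [1, 2]) and ('bob', [3])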


class _BatchSizeEstimator(object):
"""Estimates the best size for batches given historical timing.
"""