Skip to content

Commit 7381121

Browse files
authored
Add GroupByEncryptedKey transform (#36213)
* Add GroupByEncryptedKey transform * Missing requirement * lint * Import order * keep type checking * feedback * comment disclaimer * doc note * Avoid secret naming conflicts
1 parent 4194a62 commit 7381121

File tree

3 files changed

+351
-0
lines changed

3 files changed

+351
-0
lines changed

sdks/python/apache_beam/transforms/util.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222

2323
import collections
2424
import contextlib
25+
import hashlib
26+
import hmac
2527
import logging
2628
import random
2729
import re
@@ -32,10 +34,14 @@
3234
from collections.abc import Iterable
3335
from typing import TYPE_CHECKING
3436
from typing import Any
37+
from typing import List
3538
from typing import Optional
39+
from typing import Tuple
3640
from typing import TypeVar
3741
from typing import Union
3842

43+
from cryptography.fernet import Fernet
44+
3945
import apache_beam as beam
4046
from apache_beam import coders
4147
from apache_beam import pvalue
@@ -88,13 +94,16 @@
8894
'BatchElements',
8995
'CoGroupByKey',
9096
'Distinct',
97+
'GcpSecret',
98+
'GroupByEncryptedKey',
9199
'Keys',
92100
'KvSwap',
93101
'LogElements',
94102
'Regex',
95103
'Reify',
96104
'RemoveDuplicates',
97105
'Reshuffle',
106+
'Secret',
98107
'ToString',
99108
'Tee',
100109
'Values',
@@ -317,6 +326,205 @@ def RemoveDuplicates(pcoll):
317326
return pcoll | 'RemoveDuplicates' >> Distinct()
318327

319328

329+
class Secret:
  """Generic interface for obtaining sensitive key material.

  Concrete subclasses are expected to override ``get_secret_bytes`` so that it
  fetches the secret from whichever secret management system they wrap.
  """
  @staticmethod
  def generate_secret_bytes() -> bytes:
    """Creates a brand new secret key via ``Fernet.generate_key``."""
    fresh_key = Fernet.generate_key()
    return fresh_key

  def get_secret_bytes(self) -> bytes:
    """Returns the secret as a byte string.

    Raises:
      NotImplementedError: always; the base class has no secret to return.
    """
    raise NotImplementedError()
343+
344+
345+
class GcpSecret(Secret):
  """A secret manager implementation that retrieves secrets from Google Cloud
  Secret Manager.
  """
  def __init__(self, version_name: str):
    """Initializes a GcpSecret object.

    Args:
      version_name: The full version name of the secret in Google Cloud Secret
        Manager. For example:
        projects/<id>/secrets/<secret_name>/versions/1.
        For more info, see
        https://cloud.google.com/python/docs/reference/secretmanager/latest/google.cloud.secretmanager_v1beta1.services.secret_manager_service.SecretManagerServiceClient#google_cloud_secretmanager_v1beta1_services_secret_manager_service_SecretManagerServiceClient_access_secret_version
    """
    self._version_name = version_name

  def get_secret_bytes(self) -> bytes:
    """Fetches the payload of the configured secret version.

    Returns:
      The ``payload.data`` field of the ``access_secret_version`` response.

    Raises:
      RuntimeError: if anything goes wrong while fetching the secret,
        including a failure to import the Secret Manager client library. The
        original exception is attached as ``__cause__``.
    """
    try:
      # Imported inside the method (and inside the try) so that a missing
      # client library surfaces as the same RuntimeError as any other failure.
      from google.cloud import secretmanager
      client = secretmanager.SecretManagerServiceClient()
      response = client.access_secret_version(
          request={"name": self._version_name})
      return response.payload.data
    except Exception as e:
      # Chain explicitly so the root cause stays visible in the traceback.
      raise RuntimeError(
          f'Failed to retrieve secret bytes with exception {e}') from e
371+
372+
373+
class _EncryptMessage(DoFn):
  """DoFn that turns each (key, value) pair into an encrypted record.

  The output is keyed by an HMAC-SHA256 digest of the encoded key, and carries
  the Fernet-encrypted encoded key and encoded value as its payload.
  """
  def __init__(
      self,
      hmac_key_secret: Secret,
      key_coder: coders.Coder,
      value_coder: coders.Coder):
    """Stores the secret handle and the coders; no secret is fetched here."""
    self.hmac_key_secret = hmac_key_secret
    self.key_coder = key_coder
    self.value_coder = value_coder

  def setup(self):
    """Fetches the secret bytes and builds the Fernet instance."""
    # The same secret bytes are used both as the HMAC key and the Fernet key.
    secret_bytes = self.hmac_key_secret.get_secret_bytes()
    self._hmac_key = secret_bytes
    self.fernet = Fernet(secret_bytes)

  def process(self,
              element: Any) -> Iterable[Tuple[bytes, Tuple[bytes, bytes]]]:
    """Encrypts the key and value of an element.

    Args:
      element: A two-item (key, value) pair.

    Yields:
      ``(digest, (encrypted_key, encrypted_value))`` where ``digest`` is the
      HMAC-SHA256 of the encoded key.
    """
    raw_key, raw_value = element
    key_bytes = self.key_coder.encode(raw_key)
    value_bytes = self.value_coder.encode(raw_value)
    key_digest = hmac.new(self._hmac_key, key_bytes, hashlib.sha256).digest()
    encrypted_key = self.fernet.encrypt(key_bytes)
    encrypted_value = self.fernet.encrypt(value_bytes)
    yield key_digest, (encrypted_key, encrypted_value)
408+
409+
410+
class _DecryptMessage(DoFn):
  """A DoFn that decrypts the key and value of each element."""
  def __init__(
      self,
      hmac_key_secret: Secret,
      key_coder: coders.Coder,
      value_coder: coders.Coder):
    """Stores the secret handle and the coders; no secret is fetched here."""
    self.hmac_key_secret = hmac_key_secret
    self.key_coder = key_coder
    self.value_coder = value_coder

  def setup(self):
    """Fetches the secret bytes and builds the Fernet instance."""
    hmac_key = self.hmac_key_secret.get_secret_bytes()
    self.fernet = Fernet(hmac_key)

  def decode_value(self, encoded_element: Tuple[bytes, bytes]) -> Any:
    """Decrypts and decodes the value half of an (enc_key, enc_value) pair."""
    encrypted_value = encoded_element[1]
    encoded_value = self.fernet.decrypt(encrypted_value)
    return self.value_coder.decode(encoded_value)

  def filter_elements_by_key(
      self,
      encrypted_key: bytes,
      encoded_elements: Iterable[Tuple[bytes, bytes]]) -> Iterable[Any]:
    """Yields the decoded values whose decrypted key equals ``encrypted_key``.

    Despite its name, ``encrypted_key`` is compared against the *decrypted*
    (still coder-encoded) key of every element, so callers must pass the
    decrypted key bytes. ``process`` no longer calls this helper; it is kept
    so existing callers continue to work.
    """
    for e in encoded_elements:
      if encrypted_key == self.fernet.decrypt(e[0]):
        yield self.decode_value(e)

  # Right now, GBK always returns a list of elements, so we match this behavior
  # here. This does mean that the whole list will be materialized every time,
  # but passing an Iterable containing an Iterable breaks when pickling happens
  def process(
      self, element: Tuple[bytes, Iterable[Tuple[bytes, bytes]]]
  ) -> Iterable[Tuple[Any, List[Any]]]:
    """Decrypts the key and values of an element.

    Values are bucketed by their decrypted key in a single pass, so every
    element's key and value are decrypted exactly once and the grouped
    iterable is traversed only once.

    Args:
      element: A tuple containing the HMAC of the encoded key and an iterable
        of tuples of encrypted keys and values.

    Yields:
      A tuple containing the decrypted key and a list of decrypted values,
      one tuple per distinct decrypted key, in first-seen order.
    """
    unused_hmac_encoded_key, encoded_elements = element

    # Since there could be hmac collisions, bucket by the decrypted key rather
    # than trusting the HMAC. A dict keeps first-seen key order, and each list
    # keeps the values in the order they were iterated.
    values_by_encoded_key = {}
    for e in encoded_elements:
      encrypted_key, unused_encrypted_value = e
      encoded_key = self.fernet.decrypt(encrypted_key)
      values_by_encoded_key.setdefault(encoded_key,
                                       []).append(self.decode_value(e))

    for encoded_key, values in values_by_encoded_key.items():
      yield (self.key_coder.decode(encoded_key), values)
470+
471+
472+
@typehints.with_input_types(Tuple[K, V])
@typehints.with_output_types(Tuple[K, Iterable[V]])
class GroupByEncryptedKey(PTransform):
  """A PTransform that provides a secure alternative to GroupByKey.

  This transform encrypts the keys of the input PCollection, performs a
  GroupByKey on the encrypted keys, and then decrypts the keys in the output.
  This is useful when the keys contain sensitive data that should not be
  stored at rest by the runner. Note the following caveats:

  1) Runners can implement arbitrary materialization steps, so this does not
  guarantee that the whole pipeline will not have unencrypted data at rest by
  itself.
  2) If using this transform in streaming mode, this transform may not properly
  handle update compatibility checks around coders. This means that an improper
  update could lead to invalid coders, causing pipeline failure or data
  corruption. If you need to update, make sure that the input type passed into
  this transform does not change.
  """
  def __init__(self, hmac_key: Secret):
    """Initializes a GroupByEncryptedKey transform.

    Args:
      hmac_key: A Secret object that provides the secret key for HMAC and
        encryption. For example, a GcpSecret can be used to access a secret
        stored in GCP Secret Manager.
    """
    self._hmac_key = hmac_key

  def expand(self, pcoll):
    """Applies encrypt -> GroupByKey -> decrypt to ``pcoll``.

    Args:
      pcoll: The input PCollection of key-value pairs.

    Returns:
      The PCollection produced by the decrypting ParDo.

    Raises:
      ValueError: if the input has a concrete element type whose coder is not
        a key-value coder.
    """
    kv_type_hint = pcoll.element_type
    if kv_type_hint and kv_type_hint != typehints.Any:
      # The label and the explanation are one concatenated string argument;
      # the ': ' separator keeps them from running together.
      coder = coders.registry.get_coder(kv_type_hint).as_deterministic_coder(
          f'GroupByEncryptedKey {self.label}: '
          'The key coder is not deterministic. This may result in incorrect '
          'pipeline output. This can be fixed by adding a type hint to the '
          'operation preceding the GroupByKey step, and for custom key '
          'classes, by writing a deterministic custom Coder. Please see the '
          'documentation for more details.')
      if not coder.is_kv_coder():
        raise ValueError(
            'Input elements to the transform %s must be key-value pairs.' %
            self)
      key_coder = coder.key_coder()
      value_coder = coder.value_coder()
    else:
      # No usable type hint: use the registry's coder for Any on both sides.
      # NOTE(review): unlike the typed branch, this coder is not passed
      # through as_deterministic_coder -- confirm that is intended.
      key_coder = coders.registry.get_coder(typehints.Any)
      value_coder = key_coder

    return (
        pcoll
        | beam.ParDo(_EncryptMessage(self._hmac_key, key_coder, value_coder))
        | beam.GroupByKey()
        | beam.ParDo(_DecryptMessage(self._hmac_key, key_coder, value_coder)))
526+
527+
320528
class _BatchSizeEstimator(object):
321529
"""Estimates the best size for batches given historical timing.
322530
"""

0 commit comments

Comments
 (0)