|
22 | 22 |
|
23 | 23 | import collections |
24 | 24 | import contextlib |
| 25 | +import hashlib |
| 26 | +import hmac |
25 | 27 | import logging |
26 | 28 | import random |
27 | 29 | import re |
|
32 | 34 | from collections.abc import Iterable |
33 | 35 | from typing import TYPE_CHECKING |
34 | 36 | from typing import Any |
| 37 | +from typing import List |
35 | 38 | from typing import Optional |
| 39 | +from typing import Tuple |
36 | 40 | from typing import TypeVar |
37 | 41 | from typing import Union |
38 | 42 |
|
| 43 | +from cryptography.fernet import Fernet |
| 44 | + |
39 | 45 | import apache_beam as beam |
40 | 46 | from apache_beam import coders |
41 | 47 | from apache_beam import pvalue |
|
88 | 94 | 'BatchElements', |
89 | 95 | 'CoGroupByKey', |
90 | 96 | 'Distinct', |
| 97 | + 'GcpSecret', |
| 98 | + 'GroupByEncryptedKey', |
91 | 99 | 'Keys', |
92 | 100 | 'KvSwap', |
93 | 101 | 'LogElements', |
94 | 102 | 'Regex', |
95 | 103 | 'Reify', |
96 | 104 | 'RemoveDuplicates', |
97 | 105 | 'Reshuffle', |
| 106 | + 'Secret', |
98 | 107 | 'ToString', |
99 | 108 | 'Tee', |
100 | 109 | 'Values', |
@@ -317,6 +326,205 @@ def RemoveDuplicates(pcoll): |
317 | 326 | return pcoll | 'RemoveDuplicates' >> Distinct() |
318 | 327 |
|
319 | 328 |
|
class Secret:
  """Base interface for objects that supply secret key material.

  Concrete subclasses are expected to override ``get_secret_bytes`` so that it
  fetches the secret from whatever secret management system they wrap. The
  base implementation only defines the contract and a helper for creating new
  keys.
  """
  def get_secret_bytes(self) -> bytes:
    """Returns the secret as a byte string.

    Raises:
      NotImplementedError: always, in this base class; subclasses override it.
    """
    raise NotImplementedError()

  @staticmethod
  def generate_secret_bytes() -> bytes:
    """Generates a new secret key.

    Returns:
      A freshly generated Fernet key, as returned by ``Fernet.generate_key``.
    """
    fresh_key = Fernet.generate_key()
    return fresh_key
| 343 | + |
| 344 | + |
class GcpSecret(Secret):
  """A secret manager implementation that retrieves secrets from Google Cloud
  Secret Manager.
  """
  def __init__(self, version_name: str):
    """Initializes a GcpSecret object.

    Args:
      version_name: The full version name of the secret in Google Cloud Secret
        Manager. For example:
        projects/<id>/secrets/<secret_name>/versions/1.
        For more info, see
        https://cloud.google.com/python/docs/reference/secretmanager/latest/google.cloud.secretmanager_v1beta1.services.secret_manager_service.SecretManagerServiceClient#google_cloud_secretmanager_v1beta1_services_secret_manager_service_SecretManagerServiceClient_access_secret_version
    """
    self._version_name = version_name

  def get_secret_bytes(self) -> bytes:
    """Fetches the payload of the configured secret version.

    A new Secret Manager client is created on every call and the secret
    version named at construction time is accessed.

    Returns:
      The raw payload bytes of the secret version.

    Raises:
      RuntimeError: if anything goes wrong while importing the client library
        or accessing the secret. The original exception is attached as the
        cause.
    """
    try:
      # Imported lazily so that the dependency is only needed when a GcpSecret
      # is actually resolved; an ImportError is wrapped like any other failure.
      from google.cloud import secretmanager
      client = secretmanager.SecretManagerServiceClient()
      response = client.access_secret_version(
          request={"name": self._version_name})
      return response.payload.data
    except Exception as e:
      # Chain the cause explicitly so the underlying traceback is preserved.
      raise RuntimeError(
          f'Failed to retrieve secret bytes with exception {e}') from e
| 371 | + |
| 372 | + |
class _EncryptMessage(DoFn):
  """A DoFn that turns each (key, value) pair into an encrypted record.

  The output is keyed by an HMAC-SHA256 digest of the encoded key, and carries
  the Fernet-encrypted encoded key and encoded value as its payload. The same
  secret bytes are used both as the HMAC key and as the Fernet key.
  """
  def __init__(
      self,
      hmac_key_secret: Secret,
      key_coder: coders.Coder,
      value_coder: coders.Coder):
    """Stores the secret provider and the coders used to serialize elements.

    Args:
      hmac_key_secret: Secret whose bytes are fetched in ``setup``.
      key_coder: Coder used to encode the key of each element.
      value_coder: Coder used to encode the value of each element.
    """
    self.value_coder = value_coder
    self.key_coder = key_coder
    self.hmac_key_secret = hmac_key_secret

  def setup(self):
    """Resolves the secret and builds the Fernet instance used by process."""
    secret_bytes = self.hmac_key_secret.get_secret_bytes()
    self._hmac_key = secret_bytes
    self.fernet = Fernet(secret_bytes)

  def process(self,
              element: Any) -> Iterable[Tuple[bytes, Tuple[bytes, bytes]]]:
    """Encrypts the key and value of an element.

    Args:
      element: A tuple containing the key and value to be encrypted.

    Yields:
      A tuple containing the HMAC of the encoded key, and a tuple of the
      encrypted key and value.
    """
    raw_key, raw_value = element
    key_bytes = self.key_coder.encode(raw_key)
    value_bytes = self.value_coder.encode(raw_value)
    # The digest, not the ciphertext, is the grouping key of the output
    # element; the encrypted key travels alongside the value in the payload.
    key_digest = hmac.new(
        self._hmac_key, key_bytes, digestmod=hashlib.sha256).digest()
    encrypted_key = self.fernet.encrypt(key_bytes)
    encrypted_value = self.fernet.encrypt(value_bytes)
    yield key_digest, (encrypted_key, encrypted_value)
| 408 | + |
| 409 | + |
class _DecryptMessage(DoFn):
  """A DoFn that decrypts the key and values of each grouped element.

  Input elements are keyed by the HMAC of the encoded key and carry an
  iterable of (encrypted key, encrypted value) pairs. The pairs are decrypted
  with Fernet and decoded with the supplied coders.
  """
  def __init__(
      self,
      hmac_key_secret: Secret,
      key_coder: coders.Coder,
      value_coder: coders.Coder):
    """Stores the secret provider and the coders used to decode elements.

    Args:
      hmac_key_secret: Secret whose bytes are fetched in ``setup`` and used as
        the Fernet key.
      key_coder: Coder used to decode decrypted keys.
      value_coder: Coder used to decode decrypted values.
    """
    self.hmac_key_secret = hmac_key_secret
    self.key_coder = key_coder
    self.value_coder = value_coder

  def setup(self):
    """Resolves the secret and builds the Fernet instance used to decrypt."""
    hmac_key = self.hmac_key_secret.get_secret_bytes()
    self.fernet = Fernet(hmac_key)

  def decode_value(self, encoded_element: Tuple[bytes, bytes]) -> Any:
    """Decrypts and decodes the value half of an encrypted pair.

    Args:
      encoded_element: An (encrypted key, encrypted value) tuple; only the
        value at index 1 is read.

    Returns:
      The decoded value.
    """
    encrypted_value = encoded_element[1]
    encoded_value = self.fernet.decrypt(encrypted_value)
    return self.value_coder.decode(encoded_value)

  def filter_elements_by_key(
      self,
      encrypted_key: bytes,
      encoded_elements: Iterable[Tuple[bytes, bytes]]) -> Iterable[Any]:
    """Yields the decoded values whose decrypted key equals the given key.

    No longer used by ``process``; retained so existing callers keep working.

    Args:
      encrypted_key: Despite the name, this is compared against the
        *decrypted* (still coder-encoded) key bytes of each element.
      encoded_elements: Iterable of (encrypted key, encrypted value) tuples.

    Yields:
      The decoded value of every element whose decrypted key matches.
    """
    for e in encoded_elements:
      if encrypted_key == self.fernet.decrypt(e[0]):
        yield self.decode_value(e)

  # Right now, GBK always returns a list of elements, so we match this behavior
  # here. This does mean that the whole list will be materialized every time,
  # but passing an Iterable containing an Iterable breaks when pickling happens
  def process(
      self, element: Tuple[bytes, Iterable[Tuple[bytes, bytes]]]
  ) -> Iterable[Tuple[Any, List[Any]]]:
    """Decrypts the key and values of an element.

    Args:
      element: A tuple containing the HMAC of the encoded key and an iterable
        of tuples of encrypted keys and values.

    Yields:
      A tuple containing the decrypted key and a list of decrypted values.
      Keys are emitted in the order they are first seen, and values keep
      their order within the input iterable.
    """
    unused_hmac_encoded_key, encoded_elements = element

    # Since there could be hmac collisions, we group by the decrypted encoded
    # key rather than trusting the hmac. Grouping in one pass decrypts each
    # key exactly once, instead of re-scanning and re-decrypting the whole
    # iterable for every distinct key. Dicts preserve insertion order, so the
    # first-seen key order is retained.
    values_by_encoded_key = {}
    for e in encoded_elements:
      encrypted_key, unused_encrypted_value = e
      encoded_key = self.fernet.decrypt(encrypted_key)
      values_by_encoded_key.setdefault(encoded_key,
                                       []).append(self.decode_value(e))

    for encoded_key, values in values_by_encoded_key.items():
      yield self.key_coder.decode(encoded_key), values
| 470 | + |
| 471 | + |
@typehints.with_input_types(Tuple[K, V])
@typehints.with_output_types(Tuple[K, Iterable[V]])
class GroupByEncryptedKey(PTransform):
  """A PTransform that provides a secure alternative to GroupByKey.

  This transform encrypts the keys and values of the input PCollection,
  performs a GroupByKey on an HMAC of the encoded keys, and then decrypts the
  keys and values in the output.
  This is useful when the keys contain sensitive data that should not be
  stored at rest by the runner. Note the following caveats:

  1) Runners can implement arbitrary materialization steps, so this does not
  guarantee that the whole pipeline will not have unencrypted data at rest by
  itself.
  2) If using this transform in streaming mode, this transform may not properly
  handle update compatibility checks around coders. This means that an improper
  update could lead to invalid coders, causing pipeline failure or data
  corruption. If you need to update, make sure that the input type passed into
  this transform does not change.
  """
  def __init__(self, hmac_key: Secret):
    """Initializes a GroupByEncryptedKey transform.

    Args:
      hmac_key: A Secret object that provides the secret key for HMAC and
        encryption. For example, a GcpSecret can be used to access a secret
        stored in GCP Secret Manager
    """
    self._hmac_key = hmac_key

  def expand(self, pcoll):
    """Builds the encrypt -> GroupByKey -> decrypt sub-pipeline.

    The key and value coders are derived from the element type hint of the
    input PCollection when one is available, and fall back to the coder
    registered for ``Any`` otherwise.

    Args:
      pcoll: The input PCollection of key-value pairs.

    Returns:
      A PCollection of (key, list of values) tuples.

    Raises:
      ValueError: if the input type hint does not resolve to a key-value
        coder, or if the key coder cannot be made deterministic.
    """
    kv_type_hint = pcoll.element_type
    if kv_type_hint and kv_type_hint != typehints.Any:
      coder = coders.registry.get_coder(kv_type_hint)
      if not coder.is_kv_coder():
        raise ValueError(
            'Input elements to the transform %s must be key-value pairs.' %
            self)
      # Only the key feeds the HMAC that drives grouping, so only the key
      # coder has to be deterministic. Values are merely encoded, encrypted
      # and later decoded with the same coder.
      key_coder = coder.key_coder().as_deterministic_coder(
          self.label,
          error_message=(
              f'GroupByEncryptedKey {self.label}: '
              'The key coder is not deterministic. This may result in '
              'incorrect pipeline output. This can be fixed by adding a type '
              'hint to the operation preceding the GroupByKey step, and for '
              'custom key classes, by writing a deterministic custom Coder. '
              'Please see the documentation for more details.'))
      value_coder = coder.value_coder()
    else:
      # NOTE(review): the fallback coder is used for keys as-is, without a
      # determinism check — confirm this is acceptable for untyped inputs.
      key_coder = coders.registry.get_coder(typehints.Any)
      value_coder = key_coder

    return (
        pcoll
        | beam.ParDo(_EncryptMessage(self._hmac_key, key_coder, value_coder))
        | beam.GroupByKey()
        | beam.ParDo(_DecryptMessage(self._hmac_key, key_coder, value_coder)))
| 526 | + |
| 527 | + |
320 | 528 | class _BatchSizeEstimator(object): |
321 | 529 | """Estimates the best size for batches given historical timing. |
322 | 530 | """ |
|
0 commit comments