diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json index ed56f65ef50f..06bd728be6d7 100644 --- a/.github/trigger_files/beam_PostCommit_Python.json +++ b/.github/trigger_files/beam_PostCommit_Python.json @@ -1,5 +1,5 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run.", - "modification": 32 + "modification": 33 } diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index e941de9dfb64..595741f8efe4 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -755,6 +755,7 @@ class BeamModulePlugin implements Plugin { google_cloud_dataflow_java_proto_library_all: "com.google.cloud.dataflow:google-cloud-dataflow-java-proto-library-all:0.5.160304", google_cloud_datastore_v1_proto_client : "com.google.cloud.datastore:datastore-v1-proto-client:2.32.3", // [bomupgrader] sets version google_cloud_firestore : "com.google.cloud:google-cloud-firestore", // google_cloud_platform_libraries_bom sets version + google_cloud_kms : "com.google.cloud:google-cloud-kms", // google_cloud_platform_libraries_bom sets version google_cloud_pubsub : "com.google.cloud:google-cloud-pubsub", // google_cloud_platform_libraries_bom sets version google_cloud_pubsublite : "com.google.cloud:google-cloud-pubsublite", // google_cloud_platform_libraries_bom sets version // [bomupgrader] the BOM version is set by scripts/tools/bomupgrader.py. If update manually, also update @@ -765,6 +766,7 @@ class BeamModulePlugin implements Plugin { google_cloud_spanner_bom : "com.google.cloud:google-cloud-spanner-bom:$google_cloud_spanner_version", google_cloud_spanner : "com.google.cloud:google-cloud-spanner", // google_cloud_platform_libraries_bom sets version google_cloud_spanner_test : "com.google.cloud:google-cloud-spanner:$google_cloud_spanner_version:tests", + google_cloud_tink : "com.google.crypto.tink:tink:1.19.0", google_cloud_vertexai : "com.google.cloud:google-cloud-vertexai", // google_cloud_platform_libraries_bom sets version google_code_gson : "com.google.code.gson:gson:$google_code_gson_version", // google-http-client's version is explicitly declared for sdks/java/maven-archetypes/examples @@ -866,6 +868,7 @@ class BeamModulePlugin implements Plugin { proto_google_cloud_datacatalog_v1beta1 : "com.google.api.grpc:proto-google-cloud-datacatalog-v1beta1", // google_cloud_platform_libraries_bom sets version proto_google_cloud_datastore_v1 : "com.google.api.grpc:proto-google-cloud-datastore-v1", // google_cloud_platform_libraries_bom sets version proto_google_cloud_firestore_v1 : "com.google.api.grpc:proto-google-cloud-firestore-v1", // google_cloud_platform_libraries_bom sets version + proto_google_cloud_kms_v1 : "com.google.api.grpc:proto-google-cloud-kms-v1", // google_cloud_platform_libraries_bom sets version proto_google_cloud_pubsub_v1 : "com.google.api.grpc:proto-google-cloud-pubsub-v1", // google_cloud_platform_libraries_bom sets version proto_google_cloud_pubsublite_v1 : "com.google.api.grpc:proto-google-cloud-pubsublite-v1", // google_cloud_platform_libraries_bom sets version proto_google_cloud_secret_manager_v1 : "com.google.api.grpc:proto-google-cloud-secretmanager-v1", // google_cloud_platform_libraries_bom sets version diff --git a/sdks/java/build-tools/src/main/resources/beam/checkstyle/suppressions.xml b/sdks/java/build-tools/src/main/resources/beam/checkstyle/suppressions.xml index 53cd7b7ad4d0..ef4cbdb5ba02 100644 --- a/sdks/java/build-tools/src/main/resources/beam/checkstyle/suppressions.xml +++ b/sdks/java/build-tools/src/main/resources/beam/checkstyle/suppressions.xml @@ -57,6 +57,7 @@ + diff --git a/sdks/java/core/build.gradle b/sdks/java/core/build.gradle index 4f37ad47ec4c..74b6dfe4bba7 100644 --- a/sdks/java/core/build.gradle +++ b/sdks/java/core/build.gradle @@ -102,6 +102,10 @@ dependencies { shadow library.java.snappy_java shadow library.java.joda_time implementation enforcedPlatform(library.java.google_cloud_platform_libraries_bom) + implementation library.java.gax + implementation library.java.google_cloud_kms + implementation library.java.proto_google_cloud_kms_v1 + implementation library.java.google_cloud_tink implementation library.java.google_cloud_secret_manager implementation library.java.proto_google_cloud_secret_manager_v1 implementation library.java.protobuf_java @@ -130,6 +134,8 @@ dependencies { shadowTest library.java.log4j2_api shadowTest library.java.jamm shadowTest 'com.google.cloud:google-cloud-secretmanager:2.75.0' + shadowTest 'com.google.cloud:google-cloud-kms:2.75.0' + shadowTest 'com.google.crypto.tink:tink:1.19.0' testRuntimeOnly library.java.slf4j_jdk14 } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/GcpHsmGeneratedSecret.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/GcpHsmGeneratedSecret.java new file mode 100644 index 000000000000..493330ad5561 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/GcpHsmGeneratedSecret.java @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.util; + +import com.google.api.gax.rpc.AlreadyExistsException; +import com.google.api.gax.rpc.NotFoundException; +import com.google.cloud.kms.v1.CryptoKeyName; +import com.google.cloud.kms.v1.EncryptResponse; +import com.google.cloud.kms.v1.KeyManagementServiceClient; +import com.google.cloud.secretmanager.v1.AccessSecretVersionResponse; +import com.google.cloud.secretmanager.v1.ProjectName; +import com.google.cloud.secretmanager.v1.Replication; +import com.google.cloud.secretmanager.v1.SecretManagerServiceClient; +import com.google.cloud.secretmanager.v1.SecretName; +import com.google.cloud.secretmanager.v1.SecretPayload; +import com.google.cloud.secretmanager.v1.SecretVersionName; +import com.google.crypto.tink.subtle.Hkdf; +import com.google.protobuf.ByteString; +import java.io.IOException; +import java.security.GeneralSecurityException; +import java.security.SecureRandom; +import java.util.Base64; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A {@link org.apache.beam.sdk.util.Secret} manager implementation that generates a secret using + * entropy from a GCP HSM key and stores it in Google Cloud Secret Manager. If the secret already + * exists, it will be retrieved. + */ +public class GcpHsmGeneratedSecret implements Secret { + private static final Logger LOG = LoggerFactory.getLogger(GcpHsmGeneratedSecret.class); + private final String projectId; + private final String locationId; + private final String keyRingId; + private final String keyId; + private final String secretId; + + private final SecureRandom random = new SecureRandom(); + + public GcpHsmGeneratedSecret( + String projectId, String locationId, String keyRingId, String keyId, String jobName) { + this.projectId = projectId; + this.locationId = locationId; + this.keyRingId = keyRingId; + this.keyId = keyId; + this.secretId = "HsmGeneratedSecret_" + jobName; + } + + /** + * Returns the secret as a byte array. Assumes that the current active service account has + * permissions to read the secret. + * + * @return The secret as a byte array. + */ + @Override + public byte[] getSecretBytes() { + try (SecretManagerServiceClient client = SecretManagerServiceClient.create()) { + SecretVersionName secretVersionName = SecretVersionName.of(projectId, secretId, "1"); + + try { + AccessSecretVersionResponse response = client.accessSecretVersion(secretVersionName); + return response.getPayload().getData().toByteArray(); + } catch (NotFoundException e) { + LOG.info( + "Secret version {} not found. Creating new secret and version.", + secretVersionName.toString()); + } + + ProjectName projectName = ProjectName.of(projectId); + SecretName secretName = SecretName.of(projectId, secretId); + try { + com.google.cloud.secretmanager.v1.Secret secret = + com.google.cloud.secretmanager.v1.Secret.newBuilder() + .setReplication( + Replication.newBuilder() + .setAutomatic(Replication.Automatic.newBuilder().build())) + .build(); + client.createSecret(projectName, secretId, secret); + } catch (AlreadyExistsException e) { + LOG.info("Secret {} already exists. Adding new version.", secretName.toString()); + } + + byte[] newKey = generateDek(); + + try { + // Always retrieve remote secret as source-of-truth in case another thread created it + AccessSecretVersionResponse response = client.accessSecretVersion(secretVersionName); + return response.getPayload().getData().toByteArray(); + } catch (NotFoundException e) { + LOG.info( + "Secret version {} not found after re-check. Creating new secret and version.", + secretVersionName.toString()); + } + + SecretPayload payload = + SecretPayload.newBuilder().setData(ByteString.copyFrom(newKey)).build(); + client.addSecretVersion(secretName, payload); + AccessSecretVersionResponse response = client.accessSecretVersion(secretVersionName); + return response.getPayload().getData().toByteArray(); + + } catch (IOException | GeneralSecurityException e) { + throw new RuntimeException("Failed to retrieve or create secret bytes", e); + } + } + + private byte[] generateDek() throws IOException, GeneralSecurityException { + int dekSize = 32; + try (KeyManagementServiceClient client = KeyManagementServiceClient.create()) { + // 1. Generate nonce_one. This doesn't need to have baked in randomness since the + // actual randomness comes from KMS. + byte[] nonceOne = new byte[dekSize]; + random.nextBytes(nonceOne); + + // 2. Encrypt to get nonce_two + CryptoKeyName keyName = CryptoKeyName.of(projectId, locationId, keyRingId, keyId); + EncryptResponse response = client.encrypt(keyName, ByteString.copyFrom(nonceOne)); + byte[] nonceTwo = response.getCiphertext().toByteArray(); + + // 3. Generate DK + byte[] dk = new byte[dekSize]; + random.nextBytes(dk); + + // 4. Derive DEK using HKDF + byte[] dek = Hkdf.computeHkdf("HmacSha256", dk, nonceTwo, new byte[0], dekSize); + + // 5. Base64 encode + return Base64.getUrlEncoder().encode(dek); + } + } + + /** + * Returns the project ID of the secret. + * + * @return The project ID as a String. + */ + public String getProjectId() { + return projectId; + } + + /** + * Returns the location ID of the secret. + * + * @return The location ID as a String. + */ + public String getLocationId() { + return locationId; + } + + /** + * Returns the key ring ID of the secret. + * + * @return The key ring ID as a String. + */ + public String getKeyRingId() { + return keyRingId; + } + + /** + * Returns the key ID of the secret. + * + * @return The key ID as a String. + */ + public String getKeyId() { + return keyId; + } + + /** + * Returns the secret ID of the secret. + * + * @return The secret ID as a String. + */ + public String getSecretId() { + return secretId; + } +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/Secret.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/Secret.java index a75e01c9543f..f8efde0dd44c 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/Secret.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/Secret.java @@ -23,6 +23,7 @@ import java.util.HashSet; import java.util.Map; import java.util.Set; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; /** * A secret management interface used for handling sensitive data. @@ -70,16 +71,48 @@ static Secret parseSecretOption(String secretOption) { paramName, gcpSecretParams)); } } - String versionName = paramMap.get("version_name"); - if (versionName == null) { - throw new RuntimeException( - "version_name must contain a valid value for versionName parameter"); - } + String versionName = + Preconditions.checkNotNull( + paramMap.get("version_name"), + "version_name must contain a valid value for versionName parameter"); return new GcpSecret(versionName); + case "gcphsmgeneratedsecret": + Set gcpHsmGeneratedSecretParams = + new HashSet<>( + Arrays.asList("project_id", "location_id", "key_ring_id", "key_id", "job_name")); + for (String paramName : paramMap.keySet()) { + if (!gcpHsmGeneratedSecretParams.contains(paramName)) { + throw new RuntimeException( + String.format( + "Invalid secret parameter %s, GcpHsmGeneratedSecret only supports the following parameters: %s", + paramName, gcpHsmGeneratedSecretParams)); + } + } + String projectId = + Preconditions.checkNotNull( + paramMap.get("project_id"), + "project_id must contain a valid value for projectId parameter"); + String locationId = + Preconditions.checkNotNull( + paramMap.get("location_id"), + "location_id must contain a valid value for locationId parameter"); + String keyRingId = + Preconditions.checkNotNull( + paramMap.get("key_ring_id"), + "key_ring_id must contain a valid value for keyRingId parameter"); + String keyId = + Preconditions.checkNotNull( + paramMap.get("key_id"), "key_id must contain a valid value for keyId parameter"); + String jobName = + Preconditions.checkNotNull( + paramMap.get("job_name"), + "job_name must contain a valid value for jobName parameter"); + return new GcpHsmGeneratedSecret(projectId, locationId, keyRingId, keyId, jobName); default: throw new RuntimeException( String.format( - "Invalid secret type %s, currently only GcpSecret is supported", secretType)); + "Invalid secret type %s, currently only GcpSecret and GcpHsmGeneratedSecret are supported", + secretType)); } } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java index 31064470bd38..77195533ace3 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByEncryptedKeyTest.java @@ -39,6 +39,7 @@ import org.apache.beam.sdk.testing.NeedsRunner; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.util.GcpHsmGeneratedSecret; import org.apache.beam.sdk.util.GcpSecret; import org.apache.beam.sdk.util.Secret; import org.apache.beam.sdk.values.KV; @@ -102,6 +103,9 @@ public void testGroupByKeyFakeSecret() { private static final String PROJECT_ID = "apache-beam-testing"; private static final String SECRET_ID = "gbek-test"; private static Secret gcpSecret; + private static Secret gcpHsmGeneratedSecret; + private static final String KEY_RING_ID = "gbek-test-key-ring"; + private static final String KEY_ID = "gbek-test-key"; @BeforeClass public static void setup() throws IOException { @@ -131,6 +135,45 @@ public static void setup() throws IOException { .build()); } gcpSecret = new GcpSecret(secretName.toString() + "/versions/latest"); + + try { + com.google.cloud.kms.v1.KeyManagementServiceClient kmsClient = + com.google.cloud.kms.v1.KeyManagementServiceClient.create(); + String locationId = "global"; + com.google.cloud.kms.v1.KeyRingName keyRingName = + com.google.cloud.kms.v1.KeyRingName.of(PROJECT_ID, locationId, KEY_RING_ID); + com.google.cloud.kms.v1.LocationName locationName = + com.google.cloud.kms.v1.LocationName.of(PROJECT_ID, locationId); + try { + kmsClient.getKeyRing(keyRingName); + } catch (Exception e) { + kmsClient.createKeyRing( + locationName, KEY_RING_ID, com.google.cloud.kms.v1.KeyRing.newBuilder().build()); + } + + com.google.cloud.kms.v1.CryptoKeyName keyName = + com.google.cloud.kms.v1.CryptoKeyName.of(PROJECT_ID, locationId, KEY_RING_ID, KEY_ID); + try { + kmsClient.getCryptoKey(keyName); + } catch (Exception e) { + com.google.cloud.kms.v1.CryptoKey key = + com.google.cloud.kms.v1.CryptoKey.newBuilder() + .setPurpose(com.google.cloud.kms.v1.CryptoKey.CryptoKeyPurpose.ENCRYPT_DECRYPT) + .build(); + kmsClient.createCryptoKey(keyRingName, KEY_ID, key); + } + gcpHsmGeneratedSecret = + new GcpHsmGeneratedSecret( + PROJECT_ID, + locationId, + KEY_RING_ID, + KEY_ID, + String.format("gbek-test-job-%d", new SecureRandom().nextInt(10000))); + // Validate we have crypto permissions or skip these tests. + gcpHsmGeneratedSecret.getSecretBytes(); + } catch (Exception e) { + gcpHsmGeneratedSecret = null; + } } @AfterClass @@ -183,6 +226,43 @@ public void testGroupByKeyGcpSecretThrows() { assertThrows(RuntimeException.class, () -> p.run()); } + @Test + @Category(NeedsRunner.class) + public void testGroupByKeyGcpHsmGeneratedSecret() { + if (gcpHsmGeneratedSecret == null) { + return; + } + List> ungroupedPairs = + Arrays.asList( + KV.of(null, 3), + KV.of("k1", 3), + KV.of("k5", Integer.MAX_VALUE), + KV.of("k5", Integer.MIN_VALUE), + KV.of("k2", 66), + KV.of("k1", 4), + KV.of(null, 5), + KV.of("k2", -33), + KV.of("k3", 0)); + + PCollection> input = + p.apply( + Create.of(ungroupedPairs) + .withCoder(KvCoder.of(NullableCoder.of(StringUtf8Coder.of()), VarIntCoder.of()))); + + PCollection>> output = + input.apply(GroupByEncryptedKey.create(gcpHsmGeneratedSecret)); + + PAssert.that(output.apply("Sort", MapElements.via(new SortValues()))) + .containsInAnyOrder( + KV.of("k1", Arrays.asList(3, 4)), + KV.of(null, Arrays.asList(3, 5)), + KV.of("k5", Arrays.asList(Integer.MIN_VALUE, Integer.MAX_VALUE)), + KV.of("k2", Arrays.asList(-33, 66)), + KV.of("k3", Arrays.asList(0))); + + p.run(); + } + private static class SortValues extends SimpleFunction>, KV>> { @Override diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyIT.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyIT.java index 1c8168a42a03..431bdf448bea 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyIT.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupByKeyIT.java @@ -17,6 +17,10 @@ */ package org.apache.beam.sdk.transforms; +import com.google.cloud.kms.v1.CryptoKey; +import com.google.cloud.kms.v1.CryptoKeyName; +import com.google.cloud.kms.v1.KeyManagementServiceClient; +import com.google.cloud.kms.v1.KeyRingName; import com.google.cloud.secretmanager.v1.ProjectName; import com.google.cloud.secretmanager.v1.SecretManagerServiceClient; import com.google.cloud.secretmanager.v1.SecretName; @@ -33,6 +37,7 @@ import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.util.GcpHsmGeneratedSecret; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.junit.AfterClass; @@ -51,7 +56,11 @@ public class GroupByKeyIT { private static final String PROJECT_ID = "apache-beam-testing"; private static final String SECRET_ID = "gbek-test"; private static String gcpSecretVersionName; + private static String gcpHsmSecretOption; private static String secretId; + private static final String KEY_RING_ID = "gbek-it-key-ring"; + private static final String KEY_ID = "gbek-it-key"; + private static final String LOCATION_ID = "global"; @BeforeClass public static void setup() throws IOException { @@ -88,6 +97,34 @@ public static void setup() throws IOException { .build()); } gcpSecretVersionName = secretName.toString() + "/versions/latest"; + + try { + KeyManagementServiceClient kmsClient = KeyManagementServiceClient.create(); + KeyRingName keyRingName = KeyRingName.of(PROJECT_ID, LOCATION_ID, KEY_RING_ID); + com.google.cloud.kms.v1.LocationName locationName = + com.google.cloud.kms.v1.LocationName.of(PROJECT_ID, LOCATION_ID); + try { + kmsClient.getKeyRing(keyRingName); + } catch (Exception e) { + kmsClient.createKeyRing( + locationName, KEY_RING_ID, com.google.cloud.kms.v1.KeyRing.newBuilder().build()); + } + + CryptoKeyName keyName = CryptoKeyName.of(PROJECT_ID, LOCATION_ID, KEY_RING_ID, KEY_ID); + try { + kmsClient.getCryptoKey(keyName); + } catch (Exception e) { + CryptoKey key = + CryptoKey.newBuilder().setPurpose(CryptoKey.CryptoKeyPurpose.ENCRYPT_DECRYPT).build(); + kmsClient.createCryptoKey(keyRingName, KEY_ID, key); + } + gcpHsmSecretOption = + String.format( + "type:gcphsmgeneratedsecret;project_id:%s;location_id:%s;key_ring_id:%s;key_id:%s;job_name:%s", + PROJECT_ID, LOCATION_ID, KEY_RING_ID, KEY_ID, secretId); + } catch (Exception e) { + gcpHsmSecretOption = null; + } } @AfterClass @@ -135,6 +172,81 @@ public void testGroupByKeyWithValidGcpSecretOption() throws Exception { p.run(); } + @Test + public void testGroupByKeyWithValidGcpHsmGeneratedSecretOption() throws Exception { + if (gcpHsmSecretOption == null) { + // Skip test if we couldn't set up KMS + return; + } + PipelineOptions options = TestPipeline.testingPipelineOptions(); + options.setGbek(gcpHsmSecretOption); + Pipeline p = Pipeline.create(options); + List> ungroupedPairs = + Arrays.asList( + KV.of("k1", 3), + KV.of("k5", Integer.MAX_VALUE), + KV.of("k5", Integer.MIN_VALUE), + KV.of("k2", 66), + KV.of("k1", 4), + KV.of("k2", -33), + KV.of("k3", 0)); + + PCollection> input = + p.apply( + Create.of(ungroupedPairs) + .withCoder(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()))); + + PCollection>> output = input.apply(GroupByKey.create()); + + PAssert.that(output) + .containsInAnyOrder( + KV.of("k1", Arrays.asList(3, 4)), + KV.of("k5", Arrays.asList(Integer.MAX_VALUE, Integer.MIN_VALUE)), + KV.of("k2", Arrays.asList(66, -33)), + KV.of("k3", Arrays.asList(0))); + + p.run(); + } + + @Test + public void testGroupByKeyWithExistingGcpHsmGeneratedSecretOption() throws Exception { + if (gcpHsmSecretOption == null) { + // Skip test if we couldn't set up KMS + return; + } + // Create the secret beforehand + new GcpHsmGeneratedSecret(PROJECT_ID, "global", KEY_RING_ID, KEY_ID, secretId).getSecretBytes(); + + PipelineOptions options = TestPipeline.testingPipelineOptions(); + options.setGbek(gcpHsmSecretOption); + Pipeline p = Pipeline.create(options); + List> ungroupedPairs = + Arrays.asList( + KV.of("k1", 3), + KV.of("k5", Integer.MAX_VALUE), + KV.of("k5", Integer.MIN_VALUE), + KV.of("k2", 66), + KV.of("k1", 4), + KV.of("k2", -33), + KV.of("k3", 0)); + + PCollection> input = + p.apply( + Create.of(ungroupedPairs) + .withCoder(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of()))); + + PCollection>> output = input.apply(GroupByKey.create()); + + PAssert.that(output) + .containsInAnyOrder( + KV.of("k1", Arrays.asList(3, 4)), + KV.of("k5", Arrays.asList(Integer.MAX_VALUE, Integer.MIN_VALUE)), + KV.of("k2", Arrays.asList(66, -33)), + KV.of("k3", Arrays.asList(0))); + + p.run(); + } + @Test public void testGroupByKeyWithInvalidGcpSecretOption() throws Exception { if (gcpSecretVersionName == null) { diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/SecretTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/SecretTest.java index dd4b125d73fe..0acfa3963462 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/SecretTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/SecretTest.java @@ -37,6 +37,20 @@ public void testParseSecretOptionWithValidGcpSecret() { assertEquals("my_secret/versions/latest", ((GcpSecret) secret).getVersionName()); } + @Test + public void testParseSecretOptionWithValidGcpHsmGeneratedSecret() { + String secretOption = + "type:gcphsmgeneratedsecret;project_id:my-project;location_id:global;key_ring_id:my-key-ring;key_id:my-key;job_name:my-job"; + Secret secret = Secret.parseSecretOption(secretOption); + assertTrue(secret instanceof GcpHsmGeneratedSecret); + GcpHsmGeneratedSecret hsmSecret = (GcpHsmGeneratedSecret) secret; + assertEquals("my-project", hsmSecret.getProjectId()); + assertEquals("global", hsmSecret.getLocationId()); + assertEquals("my-key-ring", hsmSecret.getKeyRingId()); + assertEquals("my-key", hsmSecret.getKeyId()); + assertEquals("HsmGeneratedSecret_my-job", hsmSecret.getSecretId()); + } + @Test public void testParseSecretOptionWithMissingType() { String secretOption = "version_name:my_secret/versions/latest"; @@ -50,9 +64,7 @@ public void testParseSecretOptionWithUnsupportedType() { String secretOption = "type:unsupported;version_name:my_secret/versions/latest"; Exception exception = assertThrows(RuntimeException.class, () -> Secret.parseSecretOption(secretOption)); - assertEquals( - "Invalid secret type unsupported, currently only GcpSecret is supported", - exception.getMessage()); + assertTrue(exception.getMessage().contains("Invalid secret type unsupported")); } @Test diff --git a/sdks/python/apache_beam/transforms/core_it_test.py b/sdks/python/apache_beam/transforms/core_it_test.py index 18ae3f30f574..2cdb770b5972 100644 --- a/sdks/python/apache_beam/transforms/core_it_test.py +++ b/sdks/python/apache_beam/transforms/core_it_test.py @@ -38,6 +38,11 @@ except ImportError: secretmanager = None # type: ignore[assignment] +try: + from google.cloud import kms +except ImportError: + kms = None # type: ignore[assignment] + class GbekIT(unittest.TestCase): @classmethod @@ -74,6 +79,42 @@ def setUpClass(cls): cls.gcp_secret = GcpSecret(version_name) cls.secret_option = f'type:GcpSecret;version_name:{version_name}' + if kms is not None: + cls.kms_client = kms.KeyManagementServiceClient() + cls.location_id = 'global' + py_version = f'_py{sys.version_info.major}{sys.version_info.minor}' + secret_postfix = datetime.now().strftime('%m%d_%H%M%S') + py_version + cls.key_ring_id = 'gbekit_key_ring_tests' + cls.key_ring_path = cls.kms_client.key_ring_path( + cls.project_id, cls.location_id, cls.key_ring_id) + try: + cls.kms_client.get_key_ring(request={'name': cls.key_ring_path}) + except Exception: + parent = f'projects/{cls.project_id}/locations/{cls.location_id}' + cls.kms_client.create_key_ring( + request={ + 'parent': parent, + 'key_ring_id': cls.key_ring_id, + }) + cls.key_id = 'gbekit_key_tests' + cls.key_path = cls.kms_client.crypto_key_path( + cls.project_id, cls.location_id, cls.key_ring_id, cls.key_id) + try: + cls.kms_client.get_crypto_key(request={'name': cls.key_path}) + except Exception: + cls.kms_client.create_crypto_key( + request={ + 'parent': cls.key_ring_path, + 'crypto_key_id': cls.key_id, + 'crypto_key': { + 'purpose': kms.CryptoKey.CryptoKeyPurpose.ENCRYPT_DECRYPT + } + }) + cls.hsm_secret_option = ( + f'type:GcpHsmGeneratedSecret;project_id:{cls.project_id};' + f'location_id:{cls.location_id};key_ring_id:{cls.key_ring_id};' + f'key_id:{cls.key_id};job_name:{secret_postfix}') + @classmethod def tearDownClass(cls): if secretmanager is not None: @@ -94,6 +135,22 @@ def test_gbk_with_gbek_it(self): pipeline.run().wait_until_finish() + @pytest.mark.it_postcommit + @unittest.skipIf(secretmanager is None, 'GCP dependencies are not installed') + @unittest.skipIf(kms is None, 'GCP dependencies are not installed') + def test_gbk_with_gbek_hsm_it(self): + pipeline = TestPipeline(is_integration_test=True) + pipeline.options.view_as(SetupOptions).gbek = self.hsm_secret_option + + pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2), ('b', 3), + ('c', 4)]) + result = (pcoll_1) | beam.GroupByKey() + sorted_result = result | beam.Map(lambda x: (x[0], sorted(x[1]))) + assert_that( + sorted_result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))])) + + pipeline.run().wait_until_finish() + @pytest.mark.it_postcommit @unittest.skipIf(secretmanager is None, 'GCP dependencies are not installed') def test_combineValues_with_gbek_it(self): diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index ba79d4ddf31c..fbaab6b4ebbb 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -361,15 +361,20 @@ def parse_secret_option(secret) -> 'Secret': secret_type = param_map['type'].lower() del param_map['type'] - secret_class = None + secret_class = Secret secret_params = None if secret_type == 'gcpsecret': - secret_class = GcpSecret + secret_class = GcpSecret # type: ignore[assignment] secret_params = ['version_name'] + elif secret_type == 'gcphsmgeneratedsecret': + secret_class = GcpHsmGeneratedSecret # type: ignore[assignment] + secret_params = [ + 'project_id', 'location_id', 'key_ring_id', 'key_id', 'job_name' + ] else: raise ValueError( f'Invalid secret type {secret_type}, currently only ' - 'GcpSecret is supported') + 'GcpSecret and GcpHsmGeneratedSecret are supported') for param_name in param_map.keys(): if param_name not in secret_params: @@ -413,6 +418,155 @@ def __eq__(self, secret): return self._version_name == getattr(secret, '_version_name', None) +class GcpHsmGeneratedSecret(Secret): + """A secret manager implementation that generates a secret using a GCP HSM key + and stores it in Google Cloud Secret Manager. If the secret already exists, + it will be retrieved. + """ + def __init__( + self, + project_id: str, + location_id: str, + key_ring_id: str, + key_id: str, + job_name: str): + """Initializes a GcpHsmGeneratedSecret object. + + Args: + project_id: The GCP project ID. + location_id: The GCP location ID for the HSM key. + key_ring_id: The ID of the KMS key ring. + key_id: The ID of the KMS key. + job_name: The name of the job, used to generate a unique secret name. + """ + self._project_id = project_id + self._location_id = location_id + self._key_ring_id = key_ring_id + self._key_id = key_id + self._secret_version_name = f'HsmGeneratedSecret_{job_name}' + + def get_secret_bytes(self) -> bytes: + """Retrieves the secret bytes. + + If the secret version already exists in Secret Manager, it is retrieved. + Otherwise, a new secret and version are created. The new secret is + generated using the HSM key. + + Returns: + The secret as a byte string. + """ + try: + from google.api_core import exceptions as api_exceptions + from google.cloud import secretmanager + client = secretmanager.SecretManagerServiceClient() + + project_path = f"projects/{self._project_id}" + secret_path = f"{project_path}/secrets/{self._secret_version_name}" + # Since we may generate multiple versions when doing this on workers, + # just always take the first version added to maintain consistency. + secret_version_path = f"{secret_path}/versions/1" + + try: + response = client.access_secret_version( + request={"name": secret_version_path}) + return response.payload.data + except api_exceptions.NotFound: + # Don't bother logging yet, we'll only log if we actually add the + # secret version below + pass + + try: + client.create_secret( + request={ + "parent": project_path, + "secret_id": self._secret_version_name, + "secret": { + "replication": { + "automatic": {} + } + }, + }) + except api_exceptions.AlreadyExists: + # Don't bother logging yet, we'll only log if we actually add the + # secret version below + pass + + new_key = self.generate_dek() + try: + # Try one more time in case it was created while we were generating the + # DEK. + response = client.access_secret_version( + request={"name": secret_version_path}) + return response.payload.data + except api_exceptions.NotFound: + logging.info( + "Secret version %s not found. " + "Creating new secret and version.", + secret_version_path) + client.add_secret_version( + request={ + "parent": secret_path, "payload": { + "data": new_key + } + }) + response = client.access_secret_version( + request={"name": secret_version_path}) + return response.payload.data + + except Exception as e: + raise RuntimeError( + f'Failed to retrieve or create secret bytes for secret ' + f'{self._secret_version_name} with exception {e}') + + def generate_dek(self, dek_size: int = 32) -> bytes: + """Generates a new Data Encryption Key (DEK) using an HSM-backed key. + + This function follows a key derivation process that incorporates entropy + from the HSM-backed key into the nonce used for key derivation. + + Args: + dek_size: The size of the DEK to generate. + + Returns: + A new DEK of the specified size, url-safe base64-encoded. + """ + try: + import base64 + import os + + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives.kdf.hkdf import HKDF + from google.cloud import kms + + # 1. Generate a random nonce (nonce_one) + nonce_one = os.urandom(dek_size) + + # 2. Use the HSM-backed key to encrypt nonce_one to create nonce_two + kms_client = kms.KeyManagementServiceClient() + key_path = kms_client.crypto_key_path( + self._project_id, self._location_id, self._key_ring_id, self._key_id) + response = kms_client.encrypt( + request={ + 'name': key_path, 'plaintext': nonce_one + }) + nonce_two = response.ciphertext + + # 3. Generate a Derivation Key (DK) + dk = os.urandom(dek_size) + + # 4. Use a KDF to derive the DEK using DK and nonce_two + hkdf = HKDF( + algorithm=hashes.SHA256(), + length=dek_size, + salt=nonce_two, + info=None, + ) + dek = hkdf.derive(dk) + return base64.urlsafe_b64encode(dek) + except Exception as e: + raise RuntimeError(f'Failed to generate DEK with exception {e}') + + class _EncryptMessage(DoFn): """A DoFn that encrypts the key and value of each element.""" def __init__( diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 34e251fad1c7..dd5e19519faf 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -72,6 +72,7 @@ from apache_beam.transforms.core import FlatMapTuple from apache_beam.transforms.trigger import AfterCount from apache_beam.transforms.trigger import Repeatedly +from apache_beam.transforms.util import GcpHsmGeneratedSecret from apache_beam.transforms.util import GcpSecret from apache_beam.transforms.util import Secret from apache_beam.transforms.window import FixedWindows @@ -439,6 +440,124 @@ def test_gbek_gcp_secret_manager_throws(self): result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))])) +@unittest.skipIf(secretmanager is None, 'GCP dependencies are not installed') +class GcpHsmGeneratedSecretTest(unittest.TestCase): + def setUp(self): + self.mock_secret_manager_client = mock.MagicMock() + self.mock_kms_client = mock.MagicMock() + + # Patch the clients + self.secretmanager_patcher = mock.patch( + 'google.cloud.secretmanager.SecretManagerServiceClient', + return_value=self.mock_secret_manager_client) + self.kms_patcher = mock.patch( + 'google.cloud.kms.KeyManagementServiceClient', + return_value=self.mock_kms_client) + self.os_urandom_patcher = mock.patch('os.urandom', return_value=b'0' * 32) + self.hkdf_patcher = mock.patch( + 'cryptography.hazmat.primitives.kdf.hkdf.HKDF.derive', + return_value=b'derived_key') + + self.secretmanager_patcher.start() + self.kms_patcher.start() + self.os_urandom_patcher.start() + self.hkdf_patcher.start() + + def tearDown(self): + self.secretmanager_patcher.stop() + self.kms_patcher.stop() + self.os_urandom_patcher.stop() + self.hkdf_patcher.stop() + + def test_happy_path_secret_creation(self): + from google.api_core import exceptions as api_exceptions + + project_id = 'test-project' + location_id = 'global' + key_ring_id = 'test-key-ring' + key_id = 'test-key' + job_name = 'test-job' + + secret = GcpHsmGeneratedSecret( + project_id, location_id, key_ring_id, key_id, job_name) + + # Mock responses for secret creation path + self.mock_secret_manager_client.access_secret_version.side_effect = [ + api_exceptions.NotFound('not found'), # first check + api_exceptions.NotFound('not found'), # second check + mock.MagicMock(payload=mock.MagicMock(data=b'derived_key')) + ] + self.mock_kms_client.encrypt.return_value = mock.MagicMock( + ciphertext=b'encrypted_nonce') + + secret_bytes = secret.get_secret_bytes() + self.assertEqual(secret_bytes, b'derived_key') + + # Assertions on mocks + secret_version_path = ( + f'projects/{project_id}/secrets/{secret._secret_version_name}' + '/versions/1') + self.mock_secret_manager_client.access_secret_version.assert_any_call( + request={'name': secret_version_path}) + self.assertEqual( + self.mock_secret_manager_client.access_secret_version.call_count, 3) + self.mock_secret_manager_client.create_secret.assert_called_once() + self.mock_kms_client.encrypt.assert_called_once() + self.mock_secret_manager_client.add_secret_version.assert_called_once() + + def test_secret_already_exists(self): + from google.api_core import exceptions as api_exceptions + + project_id = 'test-project' + location_id = 'global' + key_ring_id = 'test-key-ring' + key_id = 'test-key' + job_name = 'test-job' + + secret = GcpHsmGeneratedSecret( + project_id, location_id, key_ring_id, key_id, job_name) + + # Mock responses for secret creation path + self.mock_secret_manager_client.access_secret_version.side_effect = [ + api_exceptions.NotFound('not found'), + api_exceptions.NotFound('not found'), + mock.MagicMock(payload=mock.MagicMock(data=b'derived_key')) + ] + self.mock_secret_manager_client.create_secret.side_effect = ( + api_exceptions.AlreadyExists('exists')) + self.mock_kms_client.encrypt.return_value = mock.MagicMock( + ciphertext=b'encrypted_nonce') + + secret_bytes = secret.get_secret_bytes() + self.assertEqual(secret_bytes, b'derived_key') + + # Assertions on mocks + self.mock_secret_manager_client.create_secret.assert_called_once() + self.mock_secret_manager_client.add_secret_version.assert_called_once() + + def test_secret_version_already_exists(self): + project_id = 'test-project' + location_id = 'global' + key_ring_id = 'test-key-ring' + key_id = 'test-key' + job_name = 'test-job' + + secret = GcpHsmGeneratedSecret( + project_id, location_id, key_ring_id, key_id, job_name) + + self.mock_secret_manager_client.access_secret_version.return_value = ( + mock.MagicMock(payload=mock.MagicMock(data=b'existing_dek'))) + + secret_bytes = secret.get_secret_bytes() + self.assertEqual(secret_bytes, b'existing_dek') + + # Assertions + self.mock_secret_manager_client.access_secret_version.assert_called_once() + self.mock_secret_manager_client.create_secret.assert_not_called() + self.mock_secret_manager_client.add_secret_version.assert_not_called() + self.mock_kms_client.encrypt.assert_not_called() + + class FakeClock(object): def __init__(self, now=time.time()): self._now = now diff --git a/sdks/python/setup.py b/sdks/python/setup.py index ef58b4f9c760..c74afbb52d37 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -483,6 +483,7 @@ def get_portability_package_data(): 'google-cloud-spanner>=3.0.0,<4', # GCP Packages required by ML functionality 'google-cloud-dlp>=3.0.0,<4', + 'google-cloud-kms>=3.0.0,<4', 'google-cloud-language>=2.0,<3', 'google-cloud-secret-manager>=2.0,<3', 'google-cloud-videointelligence>=2.0,<3',