diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/CipherPool.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/CipherPool.java index 24b6b77b2f..23a184df46 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/CipherPool.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/CipherPool.java @@ -47,4 +47,8 @@ public static Cipher borrowCipher(@Nonnull String cipherName) throws GeneralSecu public static void returnCipher(@Nonnull Cipher cipher) { MAPPED_POOL.offer(cipher.getAlgorithm(), cipher); } + + public static void invalidateAll() { + MAPPED_POOL.invalidateAll(); + } } diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/MappedPool.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/MappedPool.java index 4ddd23fad5..dca9d92ad3 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/MappedPool.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/MappedPool.java @@ -103,6 +103,13 @@ public int getPoolSize(K key) { return queue == null ? 0 : queue.size(); } + /** + * Invalidate all entries in the pool. + */ + public void invalidateAll() { + pool.invalidateAll(); + } + /** * Function with Exceptions to provide the pool. * diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/TransformedRecordSerializer.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/TransformedRecordSerializer.java index a21de27082..6e2a89a31d 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/TransformedRecordSerializer.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/TransformedRecordSerializer.java @@ -27,16 +27,13 @@ import com.apple.foundationdb.record.logging.LogMessageKeys; import com.apple.foundationdb.record.metadata.RecordType; import com.apple.foundationdb.tuple.Tuple; -import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.Message; -import com.apple.foundationdb.annotation.SpotBugsSuppressWarnings; import javax.annotation.Nonnull; import javax.annotation.Nullable; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.security.GeneralSecurityException; -import java.util.Arrays; import java.util.concurrent.ThreadLocalRandom; import java.util.zip.DataFormatException; import java.util.zip.Deflater; @@ -50,7 +47,8 @@ * added in the future. * *

- * This serializer will begin each serialized string with a one-byte prefix + * This serializer will begin each serialized string with a prefix + * (see {@link TransformedRecordSerializerPrefix} for details) * containing information about which transformations were performed. This * way, when deserializing, it can detect which transformations were applied * so it knows which ones it needs to use to restore the original record. @@ -78,15 +76,6 @@ */ @API(API.Status.UNSTABLE) public class TransformedRecordSerializer implements RecordSerializer { - @VisibleForTesting - protected static final int ENCODING_ENCRYPTED = 1; - @VisibleForTesting - protected static final int ENCODING_CLEAR = 2; - @VisibleForTesting - protected static final int ENCODING_COMPRESSED = 4; - // TODO: Can remove this after transition to write everything with _CLEAR. - protected static final int ENCODING_PROTO_MESSAGE_FIELD = 0x02; - protected static final int ENCODING_PROTO_TYPE_MASK = 0x07; protected static final int DEFAULT_COMPRESSION_LEVEL = Deflater.BEST_COMPRESSION; protected static final int MIN_COMPRESSION_VERSION = 1; protected static final int MAX_COMPRESSION_VERSION = 1; @@ -110,62 +99,16 @@ protected TransformedRecordSerializer(@Nonnull RecordSerializer inner, this.writeValidationRatio = writeValidationRatio; } - @SpotBugsSuppressWarnings("EI_EXPOSE_REP") - protected static class TransformState { - public boolean compressed; - public boolean encrypted; - - @Nonnull public byte[] data; - public int offset; - public int length; - - public TransformState(@Nonnull byte[] data) { - this(data, 0, data.length); - } - - public TransformState(@Nonnull byte[] data, int offset, int length) { - this.compressed = false; - this.encrypted = false; - this.data = data; - this.offset = offset; - this.length = length; - } - - @Nonnull - public byte[] getDataArray() { - if (offset == 0 && length == data.length) { - return data; - } else { - byte[] newData = Arrays.copyOfRange(data, offset, offset + length); - offset = 0; - length = newData.length; - data = newData; - return newData; - } - } - - - public void setDataArray(@Nonnull byte[] data) { - setDataArray(data, 0, data.length); - } - - public void setDataArray(@Nonnull byte[] data, int offset, int length) { - this.data = data; - this.offset = offset; - this.length = length; - } - } - - protected void compress(@Nonnull TransformState state, @Nullable StoreTimer timer) { + protected void compress(@Nonnull TransformedRecordSerializerState state, @Nullable StoreTimer timer) { long startTime = System.nanoTime(); - increment(timer, Counts.RECORD_BYTES_BEFORE_COMPRESSION, state.length); + increment(timer, Counts.RECORD_BYTES_BEFORE_COMPRESSION, state.getLength()); // compressed data stores 5 bytes of header info. Hence, it is only fruitful to compress if the uncompressed data // has more than 5 bytes otherwise the compressed data will always be more than the original. - if (state.length > 5) { + if (state.getLength() > 5) { // Compressed bytes have 5 bytes of prefixed information about the compression state. - byte[] compressed = new byte[state.length]; + byte[] compressed = new byte[state.getLength()]; // Actually compress. If we end up filling the buffer, then just // return the uncompressed value because it's pointless to compress @@ -173,31 +116,31 @@ protected void compress(@Nonnull TransformState state, @Nullable StoreTimer time Deflater compressor = new Deflater(compressionLevel); int compressedLength; try { - compressor.setInput(state.data, state.offset, state.length); + compressor.setInput(state.getData(), state.getOffset(), state.getLength()); compressor.finish(); // necessary to include checksum compressedLength = compressor.deflate(compressed, 5, compressed.length - 5, Deflater.FULL_FLUSH); } finally { compressor.end(); } if (compressedLength == compressed.length - 5) { - increment(timer, Counts.RECORD_BYTES_AFTER_COMPRESSION, state.length); - state.compressed = false; + increment(timer, Counts.RECORD_BYTES_AFTER_COMPRESSION, state.getLength()); + state.setCompressed(false); } else { // Write compression version number and uncompressed size as these // meta-data are needed when decompressing. compressed[0] = (byte)MAX_COMPRESSION_VERSION; - ByteBuffer.wrap(compressed, 1, 4).order(ByteOrder.BIG_ENDIAN).putInt(state.length); - state.compressed = true; + ByteBuffer.wrap(compressed, 1, 4).order(ByteOrder.BIG_ENDIAN).putInt(state.getLength()); + state.setCompressed(true); increment(timer, Counts.RECORD_BYTES_AFTER_COMPRESSION, compressedLength + 5); state.setDataArray(compressed, 0, compressedLength + 5); } } else { - increment(timer, Counts.RECORD_BYTES_AFTER_COMPRESSION, state.length); + increment(timer, Counts.RECORD_BYTES_AFTER_COMPRESSION, state.getLength()); } if (timer != null) { timer.recordSinceNanoTime(Events.COMPRESS_SERIALIZED_RECORD, startTime); - if (!state.compressed) { + if (!state.isCompressed()) { timer.increment(Counts.ESCHEW_RECORD_COMPRESSION); } } @@ -209,7 +152,7 @@ private void increment(@Nullable StoreTimer timer, StoreTimer.Count counter, int } } - protected void encrypt(@Nonnull TransformState state, @Nullable StoreTimer timer) throws GeneralSecurityException { + protected void encrypt(@Nonnull TransformedRecordSerializerState state, @Nullable StoreTimer timer) throws GeneralSecurityException { throw new RecordSerializationException("this serializer cannot encrypt"); } @@ -225,7 +168,7 @@ public byte[] serialize(@Nonnull RecordMetaData metaData, @Nullable StoreTimer timer) { byte[] innerSerialized = inner.serialize(metaData, recordType, rec, timer); - TransformState state = new TransformState(innerSerialized); + TransformedRecordSerializerState state = new TransformedRecordSerializerState(innerSerialized); if (compressWhenSerializing) { compress(state, timer); @@ -241,50 +184,34 @@ public byte[] serialize(@Nonnull RecordMetaData metaData, } } - int code; - if (state.compressed || state.encrypted) { - code = 0; - if (state.compressed) { - code = code | ENCODING_COMPRESSED; - } - if (state.encrypted) { - code = code | ENCODING_ENCRYPTED; - } - } else { - code = ENCODING_CLEAR; - } - - int size = state.length + 1; - byte[] serialized = new byte[size]; - serialized[0] = (byte) code; - System.arraycopy(state.data, state.offset, serialized, 1, state.length); + TransformedRecordSerializerPrefix.encodePrefix(state); if (shouldValidateSerialization()) { - validateSerialization(metaData, recordType, rec, serialized, timer); + validateSerialization(metaData, recordType, rec, state.getDataArray(), timer); } - return serialized; + return state.getDataArray(); } - protected void decompress(@Nonnull TransformState state, @Nullable StoreTimer timer) throws DataFormatException { + protected void decompress(@Nonnull TransformedRecordSerializerState state, @Nullable StoreTimer timer) throws DataFormatException { final long startTime = System.nanoTime(); // At the moment, there is only one compression version, so // we after we've verified it is in the right range, we // can just move on. If we ever introduce a new format version, // we will need to make this code more complicated. - int compressionVersion = state.data[state.offset]; + int compressionVersion = state.getData()[state.getOffset()]; if (compressionVersion < MIN_COMPRESSION_VERSION || compressionVersion > MAX_COMPRESSION_VERSION) { throw new RecordSerializationException("unknown compression version") .addLogInfo("compressionVersion", compressionVersion); } - int decompressedLength = ByteBuffer.wrap(state.data, state.offset + 1, 4).order(ByteOrder.BIG_ENDIAN).getInt(); + int decompressedLength = ByteBuffer.wrap(state.getData(), state.getOffset() + 1, 4).order(ByteOrder.BIG_ENDIAN).getInt(); byte[] decompressed = new byte[decompressedLength]; Inflater decompressor = new Inflater(); try { - decompressor.setInput(state.data, state.offset + 5, state.length - 5); + decompressor.setInput(state.getData(), state.getOffset() + 5, state.getLength() - 5); int actualDecompressedSize = decompressor.inflate(decompressed); if (actualDecompressedSize < decompressedLength) { throw new RecordSerializationException("decompressed record too small") @@ -305,7 +232,7 @@ protected void decompress(@Nonnull TransformState state, @Nullable StoreTimer ti } } - protected void decrypt(@Nonnull TransformState state, @Nullable StoreTimer timer) throws GeneralSecurityException { + protected void decrypt(@Nonnull TransformedRecordSerializerState state, @Nullable StoreTimer timer) throws GeneralSecurityException { throw new RecordSerializationException("this serializer cannot decrypt"); } @@ -316,52 +243,35 @@ public M deserialize(@Nonnull RecordMetaData metaData, @Nonnull Tuple primaryKey, @Nonnull byte[] serialized, @Nullable StoreTimer timer) { - int encoding = serialized[0]; - if (encoding != ENCODING_CLEAR && (encoding & ENCODING_PROTO_TYPE_MASK) == ENCODING_PROTO_MESSAGE_FIELD) { - // TODO: Can remove this after transition to write everything with _CLEAR. + TransformedRecordSerializerState state = new TransformedRecordSerializerState(serialized); + if (!TransformedRecordSerializerPrefix.decodePrefix(state, primaryKey)) { return inner.deserialize(metaData, primaryKey, serialized, timer); - } else { - TransformState state = new TransformState(serialized, 1, serialized.length - 1); - if (encoding != ENCODING_CLEAR) { - if ((encoding & ENCODING_COMPRESSED) == ENCODING_COMPRESSED) { - state.compressed = true; - } - if ((encoding & ENCODING_ENCRYPTED) == ENCODING_ENCRYPTED) { - state.encrypted = true; - } - if ((encoding & ~(ENCODING_COMPRESSED | ENCODING_ENCRYPTED)) != 0) { - throw new RecordSerializationException("unrecognized transformation encoding") - .addLogInfo(LogMessageKeys.META_DATA_VERSION, metaData.getVersion()) - .addLogInfo(LogMessageKeys.PRIMARY_KEY, primaryKey) - .addLogInfo("encoding", encoding); - } - } - if (state.encrypted) { - try { - decrypt(state, timer); - } catch (RecordCoreException ex) { - throw ex.addLogInfo(LogMessageKeys.META_DATA_VERSION, metaData.getVersion()) - .addLogInfo(LogMessageKeys.PRIMARY_KEY, primaryKey); - } catch (GeneralSecurityException ex) { - throw new RecordSerializationException("decryption error", ex) - .addLogInfo(LogMessageKeys.META_DATA_VERSION, metaData.getVersion()) - .addLogInfo(LogMessageKeys.PRIMARY_KEY, primaryKey); - } + } + if (state.isEncrypted()) { + try { + decrypt(state, timer); + } catch (RecordCoreException ex) { + throw ex.addLogInfo(LogMessageKeys.META_DATA_VERSION, metaData.getVersion()) + .addLogInfo(LogMessageKeys.PRIMARY_KEY, primaryKey); + } catch (GeneralSecurityException ex) { + throw new RecordSerializationException("decryption error", ex) + .addLogInfo(LogMessageKeys.META_DATA_VERSION, metaData.getVersion()) + .addLogInfo(LogMessageKeys.PRIMARY_KEY, primaryKey); } - if (state.compressed) { - try { - decompress(state, timer); - } catch (RecordCoreException ex) { - throw ex.addLogInfo(LogMessageKeys.META_DATA_VERSION, metaData.getVersion()) - .addLogInfo(LogMessageKeys.PRIMARY_KEY, primaryKey); - } catch (DataFormatException ex) { - throw new RecordSerializationException("decompression error", ex) - .addLogInfo(LogMessageKeys.META_DATA_VERSION, metaData.getVersion()) - .addLogInfo(LogMessageKeys.PRIMARY_KEY, primaryKey); - } + } + if (state.isCompressed()) { + try { + decompress(state, timer); + } catch (RecordCoreException ex) { + throw ex.addLogInfo(LogMessageKeys.META_DATA_VERSION, metaData.getVersion()) + .addLogInfo(LogMessageKeys.PRIMARY_KEY, primaryKey); + } catch (DataFormatException ex) { + throw new RecordSerializationException("decompression error", ex) + .addLogInfo(LogMessageKeys.META_DATA_VERSION, metaData.getVersion()) + .addLogInfo(LogMessageKeys.PRIMARY_KEY, primaryKey); } - return inner.deserialize(metaData, primaryKey, state.getDataArray(), timer); } + return inner.deserialize(metaData, primaryKey, state.getDataArray(), timer); } @Nonnull diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/TransformedRecordSerializerJCE.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/TransformedRecordSerializerJCE.java index 650420f285..1370e8426c 100644 --- a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/TransformedRecordSerializerJCE.java +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/TransformedRecordSerializerJCE.java @@ -31,6 +31,7 @@ import java.security.GeneralSecurityException; import java.security.Key; import java.security.SecureRandom; +import java.util.Random; /** * An extension of {@link TransformedRecordSerializer} to use JCE to encrypt and decrypt records. @@ -38,41 +39,35 @@ */ @API(API.Status.UNSTABLE) public class TransformedRecordSerializerJCE extends TransformedRecordSerializer { - - @Nullable - protected final String cipherName; - @Nullable - protected final Key encryptionKey; @Nullable - protected final SecureRandom secureRandom; + protected final TransformedRecordSerializerKeyManager keyManager; protected TransformedRecordSerializerJCE(@Nonnull RecordSerializer inner, boolean compressWhenSerializing, int compressionLevel, boolean encryptWhenSerializing, double writeValidationRatio, - @Nullable String cipherName, - @Nullable Key encryptionKey, - @Nullable SecureRandom secureRandom) { + @Nullable TransformedRecordSerializerKeyManager keyManager) { super(inner, compressWhenSerializing, compressionLevel, encryptWhenSerializing, writeValidationRatio); - this.cipherName = cipherName; - this.encryptionKey = encryptionKey; - this.secureRandom = secureRandom; + this.keyManager = keyManager; } @Override - protected void encrypt(@Nonnull TransformState state, @Nullable StoreTimer timer) throws GeneralSecurityException { - if (cipherName == null || encryptionKey == null || secureRandom == null) { - throw new RecordSerializationException("attempted to encrypt without setting cipher name and key"); + protected void encrypt(@Nonnull TransformedRecordSerializerState state, @Nullable StoreTimer timer) throws GeneralSecurityException { + if (keyManager == null) { + throw new RecordSerializationException("attempted to encrypt without setting key manager (cipher name and key)"); } long startTime = System.nanoTime(); + int keyNumber = keyManager.getSerializationKey(); + state.setKeyNumber(keyNumber); + byte[] ivData = new byte[CipherPool.IV_SIZE]; - secureRandom.nextBytes(ivData); + keyManager.getRandom(keyNumber).nextBytes(ivData); IvParameterSpec iv = new IvParameterSpec(ivData); - Cipher cipher = CipherPool.borrowCipher(cipherName); + Cipher cipher = CipherPool.borrowCipher(keyManager.getCipher(keyNumber)); try { - cipher.init(Cipher.ENCRYPT_MODE, encryptionKey, iv); + cipher.init(Cipher.ENCRYPT_MODE, keyManager.getKey(keyNumber), iv); byte[] plainText = state.getDataArray(); byte[] cipherText = cipher.doFinal(plainText); @@ -81,7 +76,7 @@ protected void encrypt(@Nonnull TransformState state, @Nullable StoreTimer timer byte[] serialized = new byte[totalSize]; System.arraycopy(iv.getIV(), 0, serialized, 0, CipherPool.IV_SIZE); System.arraycopy(cipherText, 0, serialized, CipherPool.IV_SIZE, cipherText.length); - state.encrypted = true; + state.setEncrypted(true); state.setDataArray(serialized); } finally { CipherPool.returnCipher(cipher); @@ -92,21 +87,21 @@ protected void encrypt(@Nonnull TransformState state, @Nullable StoreTimer timer } @Override - protected void decrypt(@Nonnull TransformState state, @Nullable StoreTimer timer) throws GeneralSecurityException { - if (cipherName == null || encryptionKey == null || secureRandom == null) { + protected void decrypt(@Nonnull TransformedRecordSerializerState state, @Nullable StoreTimer timer) throws GeneralSecurityException { + if (keyManager == null) { throw new RecordSerializationException("missing encryption key or provider during decryption"); } long startTime = System.nanoTime(); byte[] ivData = new byte[CipherPool.IV_SIZE]; - System.arraycopy(state.data, state.offset, ivData, 0, CipherPool.IV_SIZE); + System.arraycopy(state.getData(), state.getOffset(), ivData, 0, CipherPool.IV_SIZE); IvParameterSpec iv = new IvParameterSpec(ivData); - byte[] cipherText = new byte[state.length - CipherPool.IV_SIZE]; - System.arraycopy(state.data, state.offset + CipherPool.IV_SIZE, cipherText, 0, cipherText.length); - Cipher cipher = CipherPool.borrowCipher(cipherName); + byte[] cipherText = new byte[state.getLength() - CipherPool.IV_SIZE]; + System.arraycopy(state.getData(), state.getOffset() + CipherPool.IV_SIZE, cipherText, 0, cipherText.length); + Cipher cipher = CipherPool.borrowCipher(keyManager.getCipher(state.getKeyNumber())); try { - cipher.init(Cipher.DECRYPT_MODE, encryptionKey, iv); + cipher.init(Cipher.DECRYPT_MODE, keyManager.getKey(state.getKeyNumber()), iv); byte[] plainText = cipher.doFinal(cipherText); state.setDataArray(plainText); @@ -118,6 +113,12 @@ protected void decrypt(@Nonnull TransformState state, @Nullable StoreTimer timer } } + @Nonnull + @Override + public RecordSerializer widen() { + return new TransformedRecordSerializerJCE<>(inner.widen(), compressWhenSerializing, compressionLevel, encryptWhenSerializing, writeValidationRatio, keyManager); + } + /** * Creates a new {@link Builder TransformedRecordSerializerJCE.Builder} instance * that is backed by the default serializer for {@link Message}s, namely @@ -155,6 +156,8 @@ public static Builder newBuilder(@Nonnull RecordSerialize * @param type of {@link Message} that underlying records will use */ public static class Builder extends TransformedRecordSerializer.Builder { + @Nullable + protected TransformedRecordSerializerKeyManager keyManager; @Nullable protected String cipherName; @Nullable @@ -272,6 +275,26 @@ public Builder clearSecureRandom() { return this; } + /** + * Sets the key manager used during cryptographic operations. + * @param keyManager key manager to use for encrypting and decrypting + * @return this Builder + */ + public Builder setKeyManager(@Nonnull TransformedRecordSerializerKeyManager keyManager) { + this.keyManager = keyManager; + return this; + } + + /** + * Clears a previously set key manager + * that might have been passed to this Builder. + * @return this Builder + */ + public Builder clearKeyManager() { + this.keyManager = null; + return this; + } + /** * Construct a {@link TransformedRecordSerializerJCE} from the * parameters specified by this builder. If one has enabled @@ -282,17 +305,18 @@ public Builder clearSecureRandom() { */ @Override public TransformedRecordSerializerJCE build() { - if (encryptWhenSerializing) { - if (encryptionKey == null) { + if (keyManager == null) { + if (encryptionKey != null) { + keyManager = new FixedZeroKeyManager(encryptionKey, cipherName, secureRandom); + } else if (encryptWhenSerializing) { throw new RecordCoreArgumentException("cannot encrypt when serializing if encryption key is not set"); } - } - if (encryptionKey != null) { - if (cipherName == null) { - cipherName = CipherPool.DEFAULT_CIPHER; + } else { + if (encryptionKey != null) { + throw new RecordCoreArgumentException("cannot specify both key manager and encryption key"); } - if (secureRandom == null) { - secureRandom = new SecureRandom(); + if (cipherName != null) { + throw new RecordCoreArgumentException("cannot specify both key manager and cipher name"); } } return new TransformedRecordSerializerJCE<>( @@ -301,10 +325,56 @@ public TransformedRecordSerializerJCE build() { compressionLevel, encryptWhenSerializing, writeValidationRatio, - cipherName, - encryptionKey, - secureRandom + keyManager ); } + + } + + static class FixedZeroKeyManager implements TransformedRecordSerializerKeyManager { + private final Key encryptionKey; + private final String cipherName; + private final SecureRandom secureRandom; + + public FixedZeroKeyManager(@Nonnull Key encryptionKey, @Nullable String cipherName, @Nullable SecureRandom secureRandom) { + if (cipherName == null) { + cipherName = CipherPool.DEFAULT_CIPHER; + } + if (secureRandom == null) { + secureRandom = new SecureRandom(); + } + this.encryptionKey = encryptionKey; + this.cipherName = cipherName; + this.secureRandom = secureRandom; + } + + @Override + public int getSerializationKey() { + return 0; + } + + @Override + public Key getKey(int keyNumber) { + if (keyNumber != 0) { + throw new RecordSerializationException("only provide key number 0"); + } + return encryptionKey; + } + + @Override + public String getCipher(int keyNumber) { + if (keyNumber != 0) { + throw new RecordSerializationException("only provide key number 0"); + } + return cipherName; + } + + @Override + public Random getRandom(int keyNumber) { + if (keyNumber != 0) { + throw new RecordSerializationException("only provide key number 0"); + } + return secureRandom; + } } } diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/TransformedRecordSerializerKeyManager.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/TransformedRecordSerializerKeyManager.java new file mode 100644 index 0000000000..b94b147b57 --- /dev/null +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/TransformedRecordSerializerKeyManager.java @@ -0,0 +1,63 @@ +/* + * TransformedRecordSerializerKeyManager.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2018 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.record.provider.common; + +import com.apple.foundationdb.annotation.API; + +import java.security.Key; +import java.util.Random; + +/** + * An interface between {@link TransformedRecordSerializerJCE} and a source of keys with associated cipher algorithms. + * Each key is identified by a unique key number, which is persisted in serialized records so that the key + * can be recovered at deserialization time. + */ +@API(API.Status.EXPERIMENTAL) +public interface TransformedRecordSerializerKeyManager { + /** + * Get the key number to be used for serializing a record. + * Typically, this would be the latest key. + * @return the key number to use + */ + int getSerializationKey(); + + /** + * Get the key with the given key number. + * @param keyNumber the unique key identifier + * @return the cipher used with this key + */ + Key getKey(int keyNumber); + + /** + * Get the name of the cipher used with the given key number. + * @param keyNumber the unique key identifier + * @return the cipher used with this key + */ + String getCipher(int keyNumber); + + /** + * Get a random generator to fill IVs when encrypting. + * Normally this would be a {@link java.security.SecureRandom} and would not depend on the key. + */ + // TODO: Perhaps it would be better to have the KM give out an IvParameterSpec or something? + // Maybe wait until we have another algorithm that's different enough. + Random getRandom(int keyNumber); +} diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/TransformedRecordSerializerPrefix.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/TransformedRecordSerializerPrefix.java new file mode 100644 index 0000000000..8d85cb8aa8 --- /dev/null +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/TransformedRecordSerializerPrefix.java @@ -0,0 +1,188 @@ +/* + * TransformedRecordSerializerEncoding.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.record.provider.common; + +import com.apple.foundationdb.annotation.API; +import com.apple.foundationdb.record.RecordMetaData; +import com.apple.foundationdb.record.logging.LogMessageKeys; +import com.apple.foundationdb.tuple.Tuple; + +import javax.annotation.Nonnull; + +/** + * Helper class for {@link TransformedRecordSerializer} giving the low-level bit encoding. + * + *

+ * The format is required to be compatible with various points in the history, which means it + * must read old data compatibly. Specifically, + *

+ *

+ * + *

+ * The encoded form begins with a Protobuf varint. + * The low three bits of this prefix specify how it was encoded and the remaining bits are the encryption key number. + *

+ * Recall that a serialized record is a message using the {@link RecordMetaData#getUnionDescriptor union descriptor} + * from the {@link RecordMetaData record metadata}. That means is will be wire-type 2 plus a field numnber + * in the union. Since field numbers must be positive, this is unambiguous versus just 2 for clear text. + * The remaining prefix types correspond to other Protobuf wire types: 1 is I64, 4 + * is EGROUP, and 5 is I32. None of these are possible for the start of a + * serialized message. Finally, a key number of zero becomes a single byte prefix of 1, 2, + * 4 or 5, formerly representing an encrypted and compressed bitmask. + *

+ * + *

+ * The encrypted form currently begins with a random IV, although this might change for other algorithms. + * The compressed form begins with a compression level, which again might someday be extended. + *

+ */ +@API(API.Status.UNSTABLE) +class TransformedRecordSerializerPrefix { + protected static final int PREFIX_ENCRYPTED = 1; + protected static final int PREFIX_CLEAR = 2; + protected static final int PREFIX_COMPRESSED = 4; + protected static final int PREFIX_COMPRESSED_THEN_ENCRYPTED = 5; + + protected static final int TYPE_MASK = 0x07; + protected static final int KEY_SHIFT = 3; + + @SuppressWarnings("fallthrough") + public static boolean decodePrefix(@Nonnull TransformedRecordSerializerState state, @Nonnull Tuple primaryKey) { + final long prefix = readVarint(state, primaryKey); + final int type = (int)(prefix & TYPE_MASK); + final long remaining = prefix >> KEY_SHIFT; + if (type == PREFIX_CLEAR && remaining != 0) { + return false; // Does not have a prefix + } + boolean valid = true; + switch (type) { + case PREFIX_CLEAR: + break; + case PREFIX_COMPRESSED_THEN_ENCRYPTED: + state.setEncrypted(true); + state.setCompressed(true); + break; + case PREFIX_ENCRYPTED: + state.setEncrypted(true); + break; + case PREFIX_COMPRESSED: + state.setCompressed(true); + break; + default: + valid = false; + break; + } + if (state.isEncrypted()) { + if (remaining < Integer.MIN_VALUE || remaining > Integer.MAX_VALUE) { + valid = false; + } else { + state.setKeyNumber((int)remaining); + } + } else if (remaining != 0) { + valid = false; + } + if (!valid) { + throw new RecordSerializationException("unrecognized transformation encoding") + .addLogInfo(LogMessageKeys.PRIMARY_KEY, primaryKey) + .addLogInfo("encoding", prefix); + } + return true; + } + + public static void encodePrefix(@Nonnull TransformedRecordSerializerState state) { + long prefix; + if (!state.isCompressed() && !state.isEncrypted()) { + prefix = PREFIX_CLEAR; + } else { + prefix = 0; + if (state.isCompressed()) { + prefix |= PREFIX_COMPRESSED; + } + if (state.isEncrypted()) { + prefix |= PREFIX_ENCRYPTED; + prefix |= (long)state.getKeyNumber() << KEY_SHIFT; + } + } + int size = state.getLength() + varintSize(prefix); + byte[] serialized = new byte[size]; + int offset = writeVarint(serialized, prefix); + System.arraycopy(state.getData(), state.getOffset(), serialized, offset, state.getLength()); + state.setDataArray(serialized); + } + + protected static int varintSize(long varint) { + int nbytes = 0; + do { + varint >>>= 7; + nbytes++; + } while (varint != 0); + return nbytes; + } + + protected static long readVarint(@Nonnull TransformedRecordSerializerState state, @Nonnull Tuple primaryKey) { + long varint = 0; + int nbytes = 0; + while (true) { + if (nbytes >= state.getLength()) { + throw new RecordSerializationException("transformation prefix malformed") + .addLogInfo(LogMessageKeys.PRIMARY_KEY, primaryKey); + } + final byte b = state.getData()[state.getOffset() + nbytes]; + if (nbytes == 9 && (b & 0xFE) != 0) { + // Continuing or more than just the 64th bit. + // This also detects random garbage with the sign bits on. + throw new RecordSerializationException("transformation prefix too long") + .addLogInfo(LogMessageKeys.PRIMARY_KEY, primaryKey); + } + varint |= (long)(b & 0x7F) << (nbytes * 7); + nbytes++; + if ((b & 0x80) == 0) { + break; + } + } + state.setOffset(state.getOffset() + nbytes); + state.setLength(state.getLength() - nbytes); + return varint; + } + + protected static int writeVarint(@Nonnull byte[] into, long varint) { + int nbytes = 0; + do { + byte b = (byte)(varint & 0x7F); + varint >>>= 7; + if (varint != 0) { + b |= (byte)0x80; + } + into[nbytes] = b; + nbytes++; + } while (varint != 0); + return nbytes; + } +} diff --git a/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/TransformedRecordSerializerState.java b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/TransformedRecordSerializerState.java new file mode 100644 index 0000000000..f5b17dbbe3 --- /dev/null +++ b/fdb-record-layer-core/src/main/java/com/apple/foundationdb/record/provider/common/TransformedRecordSerializerState.java @@ -0,0 +1,120 @@ +/* + * TransformedRecordSerializerState.java + * + * This source file is part of the FoundationDB open source project + * + * Copyright 2015-2025 Apple Inc. and the FoundationDB project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.apple.foundationdb.record.provider.common; + +import com.apple.foundationdb.annotation.SpotBugsSuppressWarnings; + +import javax.annotation.Nonnull; +import java.util.Arrays; + +/** + * The internal state of serialization / deserialization, pointing to a portion of a byte array. + * Also includes information on intended / found serialization format. + */ +@SpotBugsSuppressWarnings("EI_EXPOSE_REP") +class TransformedRecordSerializerState { + private boolean compressed; + private boolean encrypted; + private int keyNumber; + + @Nonnull + private byte[] data; + private int offset; + private int length; + + public TransformedRecordSerializerState(@Nonnull byte[] data) { + this(data, 0, data.length); + } + + public TransformedRecordSerializerState(@Nonnull byte[] data, int offset, int length) { + this.data = data; + this.offset = offset; + this.length = length; + } + + public boolean isCompressed() { + return compressed; + } + + public void setCompressed(boolean compressed) { + this.compressed = compressed; + } + + public boolean isEncrypted() { + return encrypted; + } + + public void setEncrypted(boolean encrypted) { + this.encrypted = encrypted; + } + + public int getKeyNumber() { + return keyNumber; + } + + public void setKeyNumber(int keyNumber) { + this.keyNumber = keyNumber; + } + + @Nonnull + public byte[] getData() { + return data; + } + + public int getOffset() { + return offset; + } + + public void setOffset(int offset) { + this.offset = offset; + } + + public int getLength() { + return length; + } + + public void setLength(int length) { + this.length = length; + } + + @Nonnull + public byte[] getDataArray() { + if (getOffset() == 0 && getLength() == getData().length) { + return getData(); + } else { + byte[] newData = Arrays.copyOfRange(getData(), getOffset(), getOffset() + getLength()); + offset = 0; + length = newData.length; + data = newData; + return newData; + } + } + + public void setDataArray(@Nonnull byte[] data) { + setDataArray(data, 0, data.length); + } + + public void setDataArray(@Nonnull byte[] data, int offset, int length) { + this.data = data; + this.offset = offset; + this.length = length; + } +} diff --git a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/common/TransformedRecordSerializerTest.java b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/common/TransformedRecordSerializerTest.java index 282a2c6bc0..4c7cf54de9 100644 --- a/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/common/TransformedRecordSerializerTest.java +++ b/fdb-record-layer-core/src/test/java/com/apple/foundationdb/record/provider/common/TransformedRecordSerializerTest.java @@ -29,6 +29,8 @@ import com.apple.foundationdb.record.logging.LogMessageKeys; import com.apple.foundationdb.record.metadata.RecordType; import com.apple.foundationdb.tuple.Tuple; +import com.apple.test.BooleanSource; +import com.apple.test.ParameterizedTestUtils; import com.google.common.base.Strings; import com.google.common.primitives.Bytes; import com.google.protobuf.Message; @@ -38,6 +40,7 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -47,8 +50,16 @@ import javax.crypto.SecretKey; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.security.InvalidKeyException; +import java.security.Key; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.Random; import java.util.stream.Stream; import java.util.zip.Deflater; @@ -148,7 +159,7 @@ public void noTransformations() { MySimpleRecord simpleRecord = MySimpleRecord.newBuilder().setRecNo(1066L).setStrValueIndexed("Hello").build(); RecordTypeUnion unionRecord = RecordTypeUnion.newBuilder().setMySimpleRecord(simpleRecord).build(); byte[] serialized = serialize(serializer, simpleRecord); - assertEquals(TransformedRecordSerializer.ENCODING_CLEAR, serialized[0]); + assertEquals(TransformedRecordSerializerPrefix.PREFIX_CLEAR, serialized[0]); assertArrayEquals(unionRecord.toByteArray(), Arrays.copyOfRange(serialized, 1, serialized.length)); logMetrics("metrics with no transformations"); @@ -173,7 +184,7 @@ public void compressSmallRecordWhenSerializing(@Nonnull final MySimpleRecord sma // There should be no compression actually added for a small record like this RecordTypeUnion smallUnionRecord = RecordTypeUnion.newBuilder().setMySimpleRecord(smallRecord).build(); byte[] serialized = serialize(serializer, smallRecord); - assertEquals(TransformedRecordSerializer.ENCODING_CLEAR, serialized[0]); + assertEquals(TransformedRecordSerializerPrefix.PREFIX_CLEAR, serialized[0]); assertArrayEquals(smallUnionRecord.toByteArray(), Arrays.copyOfRange(serialized, 1, serialized.length)); Message deserialized = deserialize(serializer, primaryKey, serialized); assertEquals(smallRecord, deserialized); @@ -204,7 +215,7 @@ public void compressLongRecordWhenSerializing(@Nonnull final MySimpleRecord long byte[] serialized = serialize(serializer, longRecord); assertThat(storeTimer.getCount(RecordSerializer.Counts.RECORD_BYTES_BEFORE_COMPRESSION), greaterThan(storeTimer.getCount(RecordSerializer.Counts.RECORD_BYTES_AFTER_COMPRESSION))); - assertEquals(TransformedRecordSerializer.ENCODING_COMPRESSED, serialized[0]); + assertEquals(TransformedRecordSerializerPrefix.PREFIX_COMPRESSED, serialized[0]); int rawLength = largeUnionRecord.toByteArray().length; assertEquals(rawLength, ByteBuffer.wrap(serialized, 2, 4).order(ByteOrder.BIG_ENDIAN).getInt()); Message deserialized = deserialize(serializer, primaryKey, serialized); @@ -338,7 +349,7 @@ public void buildWithoutSettingEncryption() { @Test public void decryptWithoutSettingEncryption() { - List codes = Arrays.asList(TransformedRecordSerializer.ENCODING_ENCRYPTED, TransformedRecordSerializer.ENCODING_ENCRYPTED | TransformedRecordSerializer.ENCODING_COMPRESSED); + List codes = Arrays.asList(TransformedRecordSerializerPrefix.PREFIX_ENCRYPTED, TransformedRecordSerializerPrefix.PREFIX_ENCRYPTED | TransformedRecordSerializerPrefix.PREFIX_COMPRESSED); for (int code : codes) { RecordSerializationException e = assertThrows(RecordSerializationException.class, () -> { TransformedRecordSerializer serializer = TransformedRecordSerializer.newDefaultBuilder().build(); @@ -359,24 +370,29 @@ public void unrecognizedEncoding() { deserialize(serializer, Tuple.from(1066L), serialized); }); assertThat(e.getMessage(), containsString("unrecognized transformation encoding")); - assertEquals(15, e.getLogInfo().get("encoding")); + assertEquals(15L, e.getLogInfo().get("encoding")); } - @Test - public void encryptWhenSerializing() throws Exception { + @ParameterizedTest + @BooleanSource + public void encryptWhenSerializing(boolean compressToo) throws Exception { KeyGenerator keyGen = KeyGenerator.getInstance("AES"); keyGen.init(128); SecretKey key = keyGen.generateKey(); TransformedRecordSerializer serializer = TransformedRecordSerializerJCE.newDefaultBuilder() .setEncryptWhenSerializing(true) .setEncryptionKey(key) + .setCompressWhenSerializing(compressToo) + .setCompressionLevel(9) .setWriteValidationRatio(1.0) .build(); MySimpleRecord mediumRecord = MySimpleRecord.newBuilder().setRecNo(1066L).setStrValueIndexed(SONNET_108).build(); assertTrue(Bytes.indexOf(mediumRecord.toByteArray(), "brain".getBytes()) >= 0, "should contain clear text"); byte[] serialized = serialize(serializer, mediumRecord); - assertEquals(TransformedRecordSerializer.ENCODING_ENCRYPTED, serialized[0]); + assertEquals(compressToo ? TransformedRecordSerializerPrefix.PREFIX_COMPRESSED_THEN_ENCRYPTED + : TransformedRecordSerializerPrefix.PREFIX_ENCRYPTED, + serialized[0]); assertFalse(Bytes.indexOf(serialized, "brain".getBytes()) >= 0, "should not contain clear text"); Message deserialized = deserialize(serializer, Tuple.from(1066L), serialized); assertEquals(mediumRecord, deserialized); @@ -450,10 +466,283 @@ public void corruptAnyBit() { } } + @ParameterizedTest + @ValueSource(ints = {6, 10}) + public void malformedVarintEncoding(int length) { + RecordSerializationException e = assertThrows(RecordSerializationException.class, () -> { + TransformedRecordSerializer serializer = TransformedRecordSerializer.newDefaultBuilder().build(); + byte[] serialized = new byte[length]; + Arrays.fill(serialized, (byte)0xFF); + deserialize(serializer, Tuple.from(1066L), serialized); + }); + assertThat(e.getMessage(), containsString(length > 64 / 7 ? "transformation prefix too long" + : "transformation prefix malformed")); + } + + @Test + public void invalidKeyNumberEncoding() { + RecordSerializationException e = assertThrows(RecordSerializationException.class, () -> { + TransformedRecordSerializer serializer = TransformedRecordSerializer.newDefaultBuilder().build(); + byte[] serialized = new byte[10]; + TransformedRecordSerializerPrefix.writeVarint(serialized, + TransformedRecordSerializerPrefix.PREFIX_ENCRYPTED + ((long)Integer.MAX_VALUE + 1 << 3)); + deserialize(serializer, Tuple.from(1066L), serialized); + }); + assertThat(e.getMessage(), containsString("unrecognized transformation encoding")); + } + + @Test + public void encryptRollingKeys() throws Exception { + RollingKeyManager keyManager = new RollingKeyManager(); + TransformedRecordSerializer serializer = TransformedRecordSerializerJCE.newDefaultBuilder() + .setEncryptWhenSerializing(true) + .setKeyManager(keyManager) + .setWriteValidationRatio(1.0) + .build(); + + List records = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + records.add(MySimpleRecord.newBuilder() + .setRecNo(1000 + i) + .setNumValue2(i) + .setStrValueIndexed(SONNET_108) + .build()); + } + + List serialized = new ArrayList<>(); + for (MySimpleRecord record : records) { + serialized.add(serialize(serializer, record)); + } + + assertThat(keyManager.numberOfKeys(), greaterThan(5)); + + List deserialized = new ArrayList<>(); + for (int i = 0; i < serialized.size(); i++) { + deserialized.add(deserialize(serializer, Tuple.from(1000L + i), serialized.get(i))); + } + + assertEquals(records, deserialized); + } + + @Test + public void cannotDecryptUnknownKey() throws Exception { + KeyGenerator keyGen = KeyGenerator.getInstance("AES"); + keyGen.init(128); + SecretKey key = keyGen.generateKey(); + SecureRandom random = new SecureRandom(); + TransformedRecordSerializer serializer = TransformedRecordSerializerJCE.newDefaultBuilder() + .setEncryptWhenSerializing(true) + .setKeyManager(new TransformedRecordSerializerKeyManager() { + @Override + public int getSerializationKey() { + return 2; + } + + @Override + public Key getKey(final int keyNumber) { + return key; + } + + @Override + public String getCipher(final int keyNumber) { + return CipherPool.DEFAULT_CIPHER; + } + + @Override + public Random getRandom(final int keyNumber) { + return random; + } + }) + .setWriteValidationRatio(1.0) + .build(); + + MySimpleRecord simpleRecord = MySimpleRecord.newBuilder().setRecNo(1066L).setStrValueIndexed("Hello").build(); + RecordTypeUnion unionRecord = RecordTypeUnion.newBuilder().setMySimpleRecord(simpleRecord).build(); + byte[] serialized = serialize(serializer, simpleRecord); + TransformedRecordSerializer deserializer = TransformedRecordSerializerJCE.newDefaultBuilder() + .setEncryptionKey(key) + .build(); + RecordSerializationException e = assertThrows(RecordSerializationException.class, + () -> deserialize(deserializer, Tuple.from(1066L), serialized)); + assertThat(e.getMessage(), containsString("only provide key number 0")); + } + + @ParameterizedTest + @BooleanSource + public void cannotDecryptWithoutKey(boolean jce) throws Exception { + KeyGenerator keyGen = KeyGenerator.getInstance("AES"); + keyGen.init(128); + TransformedRecordSerializer serializer = TransformedRecordSerializerJCE.newDefaultBuilder() + .setEncryptWhenSerializing(true) + .setEncryptionKey(keyGen.generateKey()) + .setWriteValidationRatio(1.0) + .build(); + MySimpleRecord simpleRecord = MySimpleRecord.newBuilder().setRecNo(1066L).setStrValueIndexed("Hello").build(); + RecordTypeUnion unionRecord = RecordTypeUnion.newBuilder().setMySimpleRecord(simpleRecord).build(); + byte[] serialized = serialize(serializer, simpleRecord); + TransformedRecordSerializer deserializer; + if (jce) { + deserializer = TransformedRecordSerializerJCE.newDefaultBuilder() + .setWriteValidationRatio(1.0) + .build(); + } else { + deserializer = TransformedRecordSerializer.newDefaultBuilder() + .setWriteValidationRatio(1.0) + .build(); + } + RecordSerializationException e = assertThrows(RecordSerializationException.class, + () -> deserialize(deserializer, Tuple.from(1066L), serialized)); + assertThat(e.getMessage(), containsString(jce ? "missing encryption key or provider during decryption" + : "this serializer cannot decrypt")); + } + + @Test + public void keyDoesNotMatchAlgorithm() throws Exception { + KeyGenerator keyGen = KeyGenerator.getInstance("DES"); + keyGen.init(56); + try { + TransformedRecordSerializer serializer = TransformedRecordSerializerJCE.newDefaultBuilder() + .setEncryptWhenSerializing(true) + .setEncryptionKey(keyGen.generateKey()) + .setWriteValidationRatio(1.0) + .build(); + MySimpleRecord simpleRecord = MySimpleRecord.newBuilder().setRecNo(1066L).setStrValueIndexed("Hello").build(); + RecordTypeUnion unionRecord = RecordTypeUnion.newBuilder().setMySimpleRecord(simpleRecord).build(); + RecordSerializationException e = assertThrows(RecordSerializationException.class, + () -> serialize(serializer, simpleRecord)); + assertThat(e.getMessage(), containsString("encryption error")); + assertThat(e.getCause(), instanceOf(InvalidKeyException.class)); + assertThat(e.getCause().getMessage(), containsString("Wrong algorithm")); + } finally { + // We have put something inconsistent in. + CipherPool.invalidateAll(); + } + } + + @Test + public void changeAlgorithm() throws Exception { + KeyGenerator keyGen = KeyGenerator.getInstance("AES"); + keyGen.init(128); + TransformedRecordSerializer serializer = TransformedRecordSerializerJCE.newDefaultBuilder() + .setEncryptWhenSerializing(true) + .setEncryptionKey(keyGen.generateKey()) + .setWriteValidationRatio(1.0) + .build(); + MySimpleRecord simpleRecord = MySimpleRecord.newBuilder().setRecNo(1066L).setStrValueIndexed("Hello").build(); + RecordTypeUnion unionRecord = RecordTypeUnion.newBuilder().setMySimpleRecord(simpleRecord).build(); + byte[] serialized = serialize(serializer, simpleRecord); + KeyGenerator keyGen2 = KeyGenerator.getInstance("DES"); + keyGen2.init(56); + TransformedRecordSerializer deserializer = TransformedRecordSerializerJCE.newDefaultBuilder() + .setEncryptWhenSerializing(true) + .setCipherName("DES") + .setEncryptionKey(keyGen2.generateKey()) + .setWriteValidationRatio(1.0) + .build(); + RecordSerializationException e = assertThrows(RecordSerializationException.class, + () -> deserialize(deserializer, Tuple.from(1066L), serialized)); + assertThat(e.getMessage(), containsString("decryption error")); + } + + public static Stream compressedAndOrEncrypted() { + return ParameterizedTestUtils.cartesianProduct( + ParameterizedTestUtils.booleans("compressed"), + ParameterizedTestUtils.booleans("encrypted")); + } + + @ParameterizedTest + @MethodSource("compressedAndOrEncrypted") + public void typed(boolean compressed, boolean encrypted) throws Exception { + RecordSerializer typedSerializer = new TypedRecordSerializer<>( + TestRecords1Proto.RecordTypeUnion.getDescriptor().findFieldByNumber(TestRecords1Proto.RecordTypeUnion._MYSIMPLERECORD_FIELD_NUMBER), + TestRecords1Proto.RecordTypeUnion::newBuilder, + TestRecords1Proto.RecordTypeUnion::hasMySimpleRecord, + TestRecords1Proto.RecordTypeUnion::getMySimpleRecord, + TestRecords1Proto.RecordTypeUnion.Builder::setMySimpleRecord); + MySimpleRecord record = MySimpleRecord.newBuilder().setRecNo(1066L).setStrValueIndexed(SONNET_108).build(); + + if (encrypted) { + KeyGenerator keyGen = KeyGenerator.getInstance("AES"); + keyGen.init(128); + SecretKey key = keyGen.generateKey(); + typedSerializer = TransformedRecordSerializerJCE.newBuilder(typedSerializer) + .setEncryptWhenSerializing(true) + .setEncryptionKey(key) + .setCompressWhenSerializing(compressed) + .setWriteValidationRatio(1.0) + .build(); + } else if (compressed) { + typedSerializer = TransformedRecordSerializer.newBuilder(typedSerializer) + .setCompressWhenSerializing(true) + .setWriteValidationRatio(1.0) + .build(); + } + + byte[] typedSerialized = serialize(typedSerializer, record); + RecordSerializer untypedSerializer = typedSerializer.widen(); + byte[] untypedSerialized = serialize(untypedSerializer, record); + + MySimpleRecord typedDeserialized = deserialize(typedSerializer, Tuple.from(1066L), typedSerialized); + assertEquals(record, typedDeserialized); + typedDeserialized = deserialize(typedSerializer, Tuple.from(1066L), untypedSerialized); + assertEquals(record, typedDeserialized); + + Message untypedDeserialized = deserialize(untypedSerializer, Tuple.from(1066L), typedSerialized); + assertEquals(record, untypedDeserialized); + untypedDeserialized = deserialize(untypedSerializer, Tuple.from(1066L), untypedSerialized); + assertEquals(record, untypedDeserialized); + } + + @Test + public void defaultKeyManagerKey() throws Exception { + KeyGenerator keyGen = KeyGenerator.getInstance("AES"); + keyGen.init(128); + TransformedRecordSerializerJCE serializer = TransformedRecordSerializerJCE.newDefaultBuilder() + .setEncryptWhenSerializing(true) + .setEncryptionKey(keyGen.generateKey()) + .setWriteValidationRatio(1.0) + .build(); + TransformedRecordSerializerKeyManager keyManager = serializer.keyManager; + assertEquals(0, keyManager.getSerializationKey()); + + RecordSerializationException e = assertThrows(RecordSerializationException.class, + () -> keyManager.getKey(1)); + assertThat(e.getMessage(), containsString("only provide key number 0")); + + e = assertThrows(RecordSerializationException.class, + () -> keyManager.getCipher(1)); + assertThat(e.getMessage(), containsString("only provide key number 0")); + + e = assertThrows(RecordSerializationException.class, + () -> keyManager.getRandom(1)); + assertThat(e.getMessage(), containsString("only provide key number 0")); + } + + @Test + public void invalidKeyManagerBuilder() throws Exception { + TransformedRecordSerializerJCE.Builder builder = TransformedRecordSerializerJCE.newDefaultBuilder(); + builder.setEncryptWhenSerializing(true); + + RecordCoreArgumentException e = assertThrows(RecordCoreArgumentException.class, builder::build); + assertThat(e.getMessage(), containsString("cannot encrypt when serializing if encryption key is not set")); + + RollingKeyManager keyManager = new RollingKeyManager(); + builder.setKeyManager(keyManager); + + builder.setCipherName(CipherPool.DEFAULT_CIPHER); + e = assertThrows(RecordCoreArgumentException.class, builder::build); + assertThat(e.getMessage(), containsString("cannot specify both key manager and cipher name")); + + builder.clearEncryption(); + builder.setEncryptionKey(keyManager.getKey(keyManager.getSerializationKey())); + e = assertThrows(RecordCoreArgumentException.class, builder::build); + assertThat(e.getMessage(), containsString("cannot specify both key manager and encryption key")); + } + private boolean isCompressed(byte[] serialized) { byte headerByte = serialized[0]; - return (headerByte & TransformedRecordSerializer.ENCODING_PROTO_MESSAGE_FIELD) == 0 - && (headerByte & TransformedRecordSerializer.ENCODING_COMPRESSED) != 0; + return headerByte == TransformedRecordSerializerPrefix.PREFIX_COMPRESSED || + headerByte == TransformedRecordSerializerPrefix.PREFIX_COMPRESSED_THEN_ENCRYPTED; } private int getUncompressedSize(byte[] serialized) { @@ -505,4 +794,49 @@ public RecordSerializer widen() { throw new UnsupportedOperationException("cannot widen this serializer"); } } + + private static class RollingKeyManager implements TransformedRecordSerializerKeyManager { + private final KeyGenerator keyGenerator; + private final Map keys; + private final Random random; + + public RollingKeyManager() throws NoSuchAlgorithmException { + keyGenerator = KeyGenerator.getInstance("AES"); + keyGenerator.init(128); + keys = new HashMap<>(); + random = new SecureRandom(); + } + + @Override + public int getSerializationKey() { + int newKey = random.nextInt(); + if (!keys.containsKey(newKey)) { + keys.put(newKey, keyGenerator.generateKey()); + } + return newKey; + } + + @Override + public Key getKey(final int keyNumber) { + if (!keys.containsKey(keyNumber)) { + throw new RecordCoreArgumentException("invalid key number"); + } + return keys.get(keyNumber); + } + + @Override + public String getCipher(final int keyNumber) { + return CipherPool.DEFAULT_CIPHER; + } + + @Override + public Random getRandom(final int keyNumber) { + return random; + } + + public int numberOfKeys() { + return keys.size(); + } + } + }