diff --git a/src/core/functions/scalar/encrypt_vectorized.cpp b/src/core/functions/scalar/encrypt_vectorized.cpp index b9b9095..209ac42 100644 --- a/src/core/functions/scalar/encrypt_vectorized.cpp +++ b/src/core/functions/scalar/encrypt_vectorized.cpp @@ -164,7 +164,11 @@ void EncryptVectorizedFlat(T *input_vector, uint64_t size, ExpressionState &stat } template -void EncryptVectorized(T *input_vector, uint64_t size, ExpressionState &state, Vector &result, uint8_t vector_type) { +void EncryptVectorized(const UnifiedVectorFormat &input_data_u, uint64_t size, ExpressionState &state, Vector &result, uint8_t vector_type) { + + // get the actual data and selection vector + auto input_data = UnifiedVectorFormat::GetData(input_data_u); + auto input_data_sel = input_data_u.sel; // local state and key auto &lstate = VCryptFunctionLocalState::ResetAndGet(state); @@ -200,11 +204,13 @@ void EncryptVectorized(T *input_vector, uint64_t size, ExpressionState &state, V auto counter_vec_data = FlatVector::GetData(*counter_vec); auto cipher_vec_data = FlatVector::GetData(*cipher_vec); - // set nonce + // ----------Set Nonce ---------- + nonce_hi_data[0] = (static_cast(lstate.iv[0]) << 32) | lstate.iv[1]; nonce_lo_data[0] = lstate.iv[2]; - // result vector is a dict vector containing encrypted data + // ---------- Encrypt and store into a Dict Vector ---------- + auto &blob = children[4]; SelectionVector sel(size); blob->Slice(*blob, sel, size); @@ -234,11 +240,24 @@ void EncryptVectorized(T *input_vector, uint64_t size, ExpressionState &state, V lstate.batch_size_in_bytes = batch_size * sizeof(T); uint64_t plaintext_bytes; - // copy the input vector to the buffer for padding - memcpy(lstate.buffer_p, input_vector, size * sizeof(T)); + // We clear the buffer first to avoid leaking data + memset(lstate.buffer_p, 0, total_size); + + uint64_t offset = 0; + for (uint32_t i = 0; i < size; i++){ + if (!validity.RowIsValid(input_data_sel->get_index(i))) { + continue; + } + Store(input_data[input_data_sel->get_index(i)], lstate.buffer_p + offset); + offset += sizeof(T); + } + + // this only works for flat vectors + // memcpy(lstate.buffer_p, input_data, size * sizeof(T)); lstate.encryption_state->Process(lstate.buffer_p, total_size, lstate.buffer_p, total_size); auto index = 0; + uint64_t dict_index; auto batch_nr = 0; uint64_t buffer_offset; @@ -247,8 +266,9 @@ void EncryptVectorized(T *input_vector, uint64_t size, ExpressionState &state, V // copy the first 8 bytes of plaintext of each batch // TODO: fix for edge case; resulting bytes are less then 64 bits (=8 bytes) + // this is not trivial for dict vectors auto processed = batch_nr * BATCH_SIZE; - memcpy(&plaintext_bytes, &input_vector[processed], sizeof(uint64_t)); + memcpy(&plaintext_bytes, &input_data[input_data_sel->get_index(processed)], sizeof(uint64_t)); blob_child_data[batch_nr] = StringVector::EmptyString(blob_child, lstate.batch_size_in_bytes); @@ -258,9 +278,10 @@ void EncryptVectorized(T *input_vector, uint64_t size, ExpressionState &state, V // set index in selection vector for (uint32_t j = 0; j < batch_size; j++) { - if (!validity.RowIsValid(index)) { - continue; - } +// if (!validity.RowIsValid(index)) { +// fix validity later +// continue; +// } // set index of selection vector blob_sel.set_index(index, batch_nr); // cipher contains the (masked) position in the block @@ -295,11 +316,8 @@ uint32_t RoundUpToBlockSize(uint32_t num) { } template -void EncryptVectorizedVariable(T *input_vector, uint64_t size, ExpressionState &state, Vector &result, uint8_t vector_type) { - - // we also need to store the total size of the encrypted blob... - // maybe for strings its really impossible.. - // just do the first 64 bits for sz +void EncryptVectorizedVariable(const UnifiedVectorFormat &input_data_u, uint64_t size, ExpressionState &state, + Vector &result, uint8_t vector_type) { // Storage Layout // ---------------------------------------------------------------------------- @@ -307,6 +325,10 @@ void EncryptVectorizedVariable(T *input_vector, uint64_t size, ExpressionState & // BATCH_SIZE * 64 bytes is byte offset (could be truncated to 16 bits for small strings) // resulting bytes are total length of the encrypted data + // get the actual data and selection vector + auto input_data = UnifiedVectorFormat::GetData(input_data_u); + auto input_data_sel = input_data_u.sel; + // local and global vcrypt state auto &lstate = VCryptFunctionLocalState::ResetAndGet(state); auto vcrypt_state = VCryptBasicFun::GetVCryptState(state); @@ -374,6 +396,7 @@ void EncryptVectorizedVariable(T *input_vector, uint64_t size, ExpressionState & batch_size = size; } + uint64_t dict_index; for (uint32_t i = 0; i < batches; i++) { lstate.ResetIV(counter_init); @@ -394,7 +417,12 @@ void EncryptVectorizedVariable(T *input_vector, uint64_t size, ExpressionState & // loop through the batch to see if we have to reallocate the buffer for (uint32_t j = 0; j < batch_size; j++) { - val_size = input_vector[index].GetSize(); +// if (!validity.RowIsValid(input_data_sel->get_index(i))) { +// // Fix this later +// continue; +// } + dict_index = input_data_sel->get_index(index); + val_size = input_data[dict_index].GetSize(); current_offset += val_size; Store(current_offset, offset_buf_ptr); offset_buf_ptr += sizeof(uint64_t); @@ -411,12 +439,13 @@ void EncryptVectorizedVariable(T *input_vector, uint64_t size, ExpressionState & // loop again to store the actual values for (uint32_t j = 0; j < batch_size; j++) { - val_size = input_vector[index].GetSize(); - memcpy(batch_ptr, input_vector[index].GetDataWriteable(), val_size); + dict_index = input_data_sel->get_index(index); + val_size = input_data[dict_index].GetSize(); + memcpy(batch_ptr, input_data[dict_index].GetDataWriteable(), val_size); batch_ptr += val_size; blob_sel.set_index(index, i); - cipher_vec_data[index] = j; - counter_vec_data[index] = counter_init; + cipher_vec_data[dict_index] = j; + counter_vec_data[dict_index] = counter_init; index++; } @@ -454,46 +483,38 @@ static void EncryptDataVectorized(DataChunk &args, ExpressionState &state, auto size = args.size(); auto &input_vector = args.data[0]; - UnifiedVectorFormat vdata_input; - input_vector.ToUnifiedFormat(args.size(), vdata_input); + UnifiedVectorFormat input_data_u; + input_vector.ToUnifiedFormat(args.size(), input_data_u); switch (vector_type) { case LogicalTypeId::TINYINT: case LogicalTypeId::UTINYINT: - return EncryptVectorized((int8_t *)vdata_input.data, - size, state, result, uint8_t(vector_type)); + return EncryptVectorized(input_data_u, size, state, result, uint8_t(vector_type)); case LogicalTypeId::SMALLINT: case LogicalTypeId::USMALLINT: - return EncryptVectorized((int16_t *)vdata_input.data, - size, state, result, uint8_t(vector_type)); + return EncryptVectorized(input_data_u, size, state, result, uint8_t(vector_type)); case LogicalTypeId::INTEGER: case LogicalTypeId::DATE: - return EncryptVectorized((int32_t *)vdata_input.data, + return EncryptVectorized(input_data_u, size, state, result, uint8_t(vector_type)); case LogicalTypeId::UINTEGER: - return EncryptVectorized((uint32_t *)vdata_input.data, - size, state, result, uint8_t(vector_type)); + return EncryptVectorized(input_data_u, size, state, result, uint8_t(vector_type)); case LogicalTypeId::BIGINT: case LogicalTypeId::TIMESTAMP: - return EncryptVectorized((int64_t *)vdata_input.data, - size, state, result, uint8_t(vector_type)); + return EncryptVectorized(input_data_u, size, state, result, uint8_t(vector_type)); case LogicalTypeId::UBIGINT: - return EncryptVectorized((uint64_t *)vdata_input.data, - size, state, result, uint8_t(vector_type)); + return EncryptVectorized(input_data_u, size, state, result, uint8_t(vector_type)); case LogicalTypeId::FLOAT: - return EncryptVectorized((float *)vdata_input.data, - size, state, result, uint8_t(vector_type)); + return EncryptVectorized(input_data_u, size, state, result, uint8_t(vector_type)); case LogicalTypeId::DOUBLE: - return EncryptVectorized((double *)vdata_input.data, - size, state, result, uint8_t(vector_type)); + return EncryptVectorized(input_data_u, size, state, result, uint8_t(vector_type)); case LogicalTypeId::VARCHAR: case LogicalTypeId::VARINT: case LogicalTypeId::CHAR: case LogicalTypeId::BLOB: case LogicalTypeId::MAP: case LogicalTypeId::LIST: - return EncryptVectorizedVariable((string_t *)vdata_input.data, - size, state, result, uint8_t(vector_type)); + return EncryptVectorizedVariable(input_data_u, size, state, result, uint8_t(vector_type)); default: throw NotImplementedException("Unsupported type for Encryption"); } diff --git a/src/core/types.cpp b/src/core/types.cpp index 828006c..1a935a6 100644 --- a/src/core/types.cpp +++ b/src/core/types.cpp @@ -104,6 +104,33 @@ LogicalType EncryptionTypes::GetEncryptionType(LogicalTypeId ltype) { } } +string EncryptionTypes::ToString(LogicalTypeId ltype) { + switch (ltype) { + case LogicalType::INTEGER: + return "E_INTEGER"; + case LogicalType::UINTEGER: + return "E_UINTEGER"; + case LogicalType::BIGINT: + return "E_BIGINT"; + case LogicalType::UBIGINT: + return "E_UBIGINT"; + case LogicalType::VARCHAR: + return "E_VARCHAR"; + case LogicalTypeId::DATE: + return "E_DATE"; + case LogicalTypeId::TIMESTAMP: + return "E_TIMESTAMP"; + case LogicalTypeId::CHAR: + return "E_CHAR"; + case LogicalTypeId::FLOAT: + return "E_FLOAT"; + case LogicalTypeId::DOUBLE: + return "E_DOUBLE"; + default: + throw InternalException("LogicalType not convertible to Encrypted type"); + } +} + // basic encrypted type // todo; we can just use one encrypted type, and just emplace the original type in the type modifiers... // the encrypted type just then needs an input (the original type) diff --git a/src/include/vcrypt/core/types.hpp b/src/include/vcrypt/core/types.hpp index 0b179db..7783c57 100644 --- a/src/include/vcrypt/core/types.hpp +++ b/src/include/vcrypt/core/types.hpp @@ -8,15 +8,12 @@ namespace core { struct EncryptionTypes { static LogicalType E_INTEGER(); - static LogicalType EA_INTEGER(); static LogicalType E_UINTEGER(); - static LogicalType EA_UINTEGER(); static LogicalType E_BIGINT(); static LogicalType E_UBIGINT(); static LogicalType E_VARCHAR(); static LogicalType E_DATE(); static LogicalType E_TIMESTAMP(); - static LogicalType E_DECIMAL(); static LogicalType E_FLOAT(); static LogicalType E_DOUBLE(); static LogicalType E_CHAR(); @@ -25,6 +22,7 @@ struct EncryptionTypes { static void Register(DatabaseInstance &db); static LogicalType GetBasicEncryptedType(); static LogicalType GetEncryptionType(LogicalTypeId ltype); + static string ToString(LogicalTypeId ltype); static vector IsAvailable(); static LogicalType GetOriginalType(EncryptedType etype); static EncryptedType GetEncryptedType(LogicalTypeId ltype);