Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 59 additions & 38 deletions src/core/functions/scalar/encrypt_vectorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,11 @@ void EncryptVectorizedFlat(T *input_vector, uint64_t size, ExpressionState &stat
}

template <typename T>
void EncryptVectorized(T *input_vector, uint64_t size, ExpressionState &state, Vector &result, uint8_t vector_type) {
void EncryptVectorized(const UnifiedVectorFormat &input_data_u, uint64_t size, ExpressionState &state, Vector &result, uint8_t vector_type) {

// get the actual data and selection vector
auto input_data = UnifiedVectorFormat::GetData<T>(input_data_u);
auto input_data_sel = input_data_u.sel;

// local state and key
auto &lstate = VCryptFunctionLocalState::ResetAndGet(state);
Expand Down Expand Up @@ -200,11 +204,13 @@ void EncryptVectorized(T *input_vector, uint64_t size, ExpressionState &state, V
auto counter_vec_data = FlatVector::GetData<uint32_t>(*counter_vec);
auto cipher_vec_data = FlatVector::GetData<uint16_t>(*cipher_vec);

// set nonce
// ----------Set Nonce ----------

nonce_hi_data[0] = (static_cast<uint64_t>(lstate.iv[0]) << 32) | lstate.iv[1];
nonce_lo_data[0] = lstate.iv[2];

// result vector is a dict vector containing encrypted data
// ---------- Encrypt and store into a Dict Vector ----------

auto &blob = children[4];
SelectionVector sel(size);
blob->Slice(*blob, sel, size);
Expand Down Expand Up @@ -234,11 +240,24 @@ void EncryptVectorized(T *input_vector, uint64_t size, ExpressionState &state, V
lstate.batch_size_in_bytes = batch_size * sizeof(T);
uint64_t plaintext_bytes;

// copy the input vector to the buffer for padding
memcpy(lstate.buffer_p, input_vector, size * sizeof(T));
// We clear the buffer first to avoid leaking data
memset(lstate.buffer_p, 0, total_size);

uint64_t offset = 0;
for (uint32_t i = 0; i < size; i++){
if (!validity.RowIsValid(input_data_sel->get_index(i))) {
continue;
}
Store<T>(input_data[input_data_sel->get_index(i)], lstate.buffer_p + offset);
offset += sizeof(T);
}

// this only works for flat vectors
// memcpy(lstate.buffer_p, input_data, size * sizeof(T));
lstate.encryption_state->Process(lstate.buffer_p, total_size, lstate.buffer_p, total_size);

auto index = 0;
uint64_t dict_index;
auto batch_nr = 0;
uint64_t buffer_offset;

Expand All @@ -247,8 +266,9 @@ void EncryptVectorized(T *input_vector, uint64_t size, ExpressionState &state, V

// copy the first 8 bytes of plaintext of each batch
// TODO: fix for edge case; resulting bytes are less then 64 bits (=8 bytes)
// this is not trivial for dict vectors
auto processed = batch_nr * BATCH_SIZE;
memcpy(&plaintext_bytes, &input_vector[processed], sizeof(uint64_t));
memcpy(&plaintext_bytes, &input_data[input_data_sel->get_index(processed)], sizeof(uint64_t));

blob_child_data[batch_nr] =
StringVector::EmptyString(blob_child, lstate.batch_size_in_bytes);
Expand All @@ -258,9 +278,10 @@ void EncryptVectorized(T *input_vector, uint64_t size, ExpressionState &state, V

// set index in selection vector
for (uint32_t j = 0; j < batch_size; j++) {
if (!validity.RowIsValid(index)) {
continue;
}
// if (!validity.RowIsValid(index)) {
// fix validity later
// continue;
// }
// set index of selection vector
blob_sel.set_index(index, batch_nr);
// cipher contains the (masked) position in the block
Expand Down Expand Up @@ -295,18 +316,19 @@ uint32_t RoundUpToBlockSize(uint32_t num) {
}

template <typename T>
void EncryptVectorizedVariable(T *input_vector, uint64_t size, ExpressionState &state, Vector &result, uint8_t vector_type) {

// we also need to store the total size of the encrypted blob...
// maybe for strings its really impossible..
// just do the first 64 bits for sz
void EncryptVectorizedVariable(const UnifiedVectorFormat &input_data_u, uint64_t size, ExpressionState &state,
Vector &result, uint8_t vector_type) {

// Storage Layout
// ----------------------------------------------------------------------------
// 8 bytes VCrypt version
// BATCH_SIZE * 64 bytes is byte offset (could be truncated to 16 bits for small strings)
// resulting bytes are total length of the encrypted data

// get the actual data and selection vector
auto input_data = UnifiedVectorFormat::GetData<T>(input_data_u);
auto input_data_sel = input_data_u.sel;

// local and global vcrypt state
auto &lstate = VCryptFunctionLocalState::ResetAndGet(state);
auto vcrypt_state = VCryptBasicFun::GetVCryptState(state);
Expand Down Expand Up @@ -374,6 +396,7 @@ void EncryptVectorizedVariable(T *input_vector, uint64_t size, ExpressionState &
batch_size = size;
}

uint64_t dict_index;
for (uint32_t i = 0; i < batches; i++) {
lstate.ResetIV<T>(counter_init);

Expand All @@ -394,7 +417,12 @@ void EncryptVectorizedVariable(T *input_vector, uint64_t size, ExpressionState &

// loop through the batch to see if we have to reallocate the buffer
for (uint32_t j = 0; j < batch_size; j++) {
val_size = input_vector[index].GetSize();
// if (!validity.RowIsValid(input_data_sel->get_index(i))) {
// // Fix this later
// continue;
// }
dict_index = input_data_sel->get_index(index);
val_size = input_data[dict_index].GetSize();
current_offset += val_size;
Store<uint64_t>(current_offset, offset_buf_ptr);
offset_buf_ptr += sizeof(uint64_t);
Expand All @@ -411,12 +439,13 @@ void EncryptVectorizedVariable(T *input_vector, uint64_t size, ExpressionState &

// loop again to store the actual values
for (uint32_t j = 0; j < batch_size; j++) {
val_size = input_vector[index].GetSize();
memcpy(batch_ptr, input_vector[index].GetDataWriteable(), val_size);
dict_index = input_data_sel->get_index(index);
val_size = input_data[dict_index].GetSize();
memcpy(batch_ptr, input_data[dict_index].GetDataWriteable(), val_size);
batch_ptr += val_size;
blob_sel.set_index(index, i);
cipher_vec_data[index] = j;
counter_vec_data[index] = counter_init;
cipher_vec_data[dict_index] = j;
counter_vec_data[dict_index] = counter_init;
index++;
}

Expand Down Expand Up @@ -454,46 +483,38 @@ static void EncryptDataVectorized(DataChunk &args, ExpressionState &state,
auto size = args.size();

auto &input_vector = args.data[0];
UnifiedVectorFormat vdata_input;
input_vector.ToUnifiedFormat(args.size(), vdata_input);
UnifiedVectorFormat input_data_u;
input_vector.ToUnifiedFormat(args.size(), input_data_u);

switch (vector_type) {
case LogicalTypeId::TINYINT:
case LogicalTypeId::UTINYINT:
return EncryptVectorized<int8_t>((int8_t *)vdata_input.data,
size, state, result, uint8_t(vector_type));
return EncryptVectorized<int8_t>(input_data_u, size, state, result, uint8_t(vector_type));
case LogicalTypeId::SMALLINT:
case LogicalTypeId::USMALLINT:
return EncryptVectorized<int16_t>((int16_t *)vdata_input.data,
size, state, result, uint8_t(vector_type));
return EncryptVectorized<int16_t>(input_data_u, size, state, result, uint8_t(vector_type));
case LogicalTypeId::INTEGER:
case LogicalTypeId::DATE:
return EncryptVectorized<int32_t>((int32_t *)vdata_input.data,
return EncryptVectorized<int32_t>(input_data_u,
size, state, result, uint8_t(vector_type));
case LogicalTypeId::UINTEGER:
return EncryptVectorized<uint32_t>((uint32_t *)vdata_input.data,
size, state, result, uint8_t(vector_type));
return EncryptVectorized<uint32_t>(input_data_u, size, state, result, uint8_t(vector_type));
case LogicalTypeId::BIGINT:
case LogicalTypeId::TIMESTAMP:
return EncryptVectorized<int64_t>((int64_t *)vdata_input.data,
size, state, result, uint8_t(vector_type));
return EncryptVectorized<int64_t>(input_data_u, size, state, result, uint8_t(vector_type));
case LogicalTypeId::UBIGINT:
return EncryptVectorized<uint64_t>((uint64_t *)vdata_input.data,
size, state, result, uint8_t(vector_type));
return EncryptVectorized<uint64_t>(input_data_u, size, state, result, uint8_t(vector_type));
case LogicalTypeId::FLOAT:
return EncryptVectorized<float>((float *)vdata_input.data,
size, state, result, uint8_t(vector_type));
return EncryptVectorized<float>(input_data_u, size, state, result, uint8_t(vector_type));
case LogicalTypeId::DOUBLE:
return EncryptVectorized<double>((double *)vdata_input.data,
size, state, result, uint8_t(vector_type));
return EncryptVectorized<double>(input_data_u, size, state, result, uint8_t(vector_type));
case LogicalTypeId::VARCHAR:
case LogicalTypeId::VARINT:
case LogicalTypeId::CHAR:
case LogicalTypeId::BLOB:
case LogicalTypeId::MAP:
case LogicalTypeId::LIST:
return EncryptVectorizedVariable<string_t>((string_t *)vdata_input.data,
size, state, result, uint8_t(vector_type));
return EncryptVectorizedVariable<string_t>(input_data_u, size, state, result, uint8_t(vector_type));
default:
throw NotImplementedException("Unsupported type for Encryption");
}
Expand Down
27 changes: 27 additions & 0 deletions src/core/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,33 @@ LogicalType EncryptionTypes::GetEncryptionType(LogicalTypeId ltype) {
}
}

// Maps a plain logical type id to the label of its encrypted counterpart
// (for example INTEGER -> "E_INTEGER").
// Throws InternalException for any type id that is not listed below.
string EncryptionTypes::ToString(LogicalTypeId ltype) {
	const char *encrypted_name = nullptr;

	switch (ltype) {
	// integral types
	case LogicalTypeId::INTEGER:
		encrypted_name = "E_INTEGER";
		break;
	case LogicalTypeId::UINTEGER:
		encrypted_name = "E_UINTEGER";
		break;
	case LogicalTypeId::BIGINT:
		encrypted_name = "E_BIGINT";
		break;
	case LogicalTypeId::UBIGINT:
		encrypted_name = "E_UBIGINT";
		break;
	// string-like types
	case LogicalTypeId::VARCHAR:
		encrypted_name = "E_VARCHAR";
		break;
	case LogicalTypeId::CHAR:
		encrypted_name = "E_CHAR";
		break;
	// temporal types
	case LogicalTypeId::DATE:
		encrypted_name = "E_DATE";
		break;
	case LogicalTypeId::TIMESTAMP:
		encrypted_name = "E_TIMESTAMP";
		break;
	// floating point types
	case LogicalTypeId::FLOAT:
		encrypted_name = "E_FLOAT";
		break;
	case LogicalTypeId::DOUBLE:
		encrypted_name = "E_DOUBLE";
		break;
	default:
		// no encrypted label exists for this type id
		throw InternalException("LogicalType not convertible to Encrypted type");
	}

	return encrypted_name;
}

// basic encrypted type
// TODO: we could use a single encrypted type and emplace the original type in the type modifiers;
// the encrypted type would then only need one input (the original type).
Expand Down
4 changes: 1 addition & 3 deletions src/include/vcrypt/core/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,12 @@ namespace core {

struct EncryptionTypes {
static LogicalType E_INTEGER();
static LogicalType EA_INTEGER();
static LogicalType E_UINTEGER();
static LogicalType EA_UINTEGER();
static LogicalType E_BIGINT();
static LogicalType E_UBIGINT();
static LogicalType E_VARCHAR();
static LogicalType E_DATE();
static LogicalType E_TIMESTAMP();
static LogicalType E_DECIMAL();
static LogicalType E_FLOAT();
static LogicalType E_DOUBLE();
static LogicalType E_CHAR();
Expand All @@ -25,6 +22,7 @@ struct EncryptionTypes {
static void Register(DatabaseInstance &db);
static LogicalType GetBasicEncryptedType();
static LogicalType GetEncryptionType(LogicalTypeId ltype);
static string ToString(LogicalTypeId ltype);
static vector<LogicalType> IsAvailable();
static LogicalType GetOriginalType(EncryptedType etype);
static EncryptedType GetEncryptedType(LogicalTypeId ltype);
Expand Down
Loading