Skip to content

Commit 8d269f5

Browse files
author
Lotte Felius
authored
Merge pull request #100 from ccfelius/demo_fixes
fixing bug in encryption
2 parents 6403ed9 + c08730f commit 8d269f5

File tree

3 files changed

+87
-41
lines changed

3 files changed

+87
-41
lines changed

src/core/functions/scalar/encrypt_vectorized.cpp

Lines changed: 59 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,11 @@ void EncryptVectorizedFlat(T *input_vector, uint64_t size, ExpressionState &stat
164164
}
165165

166166
template <typename T>
167-
void EncryptVectorized(T *input_vector, uint64_t size, ExpressionState &state, Vector &result, uint8_t vector_type) {
167+
void EncryptVectorized(const UnifiedVectorFormat &input_data_u, uint64_t size, ExpressionState &state, Vector &result, uint8_t vector_type) {
168+
169+
// get the actual data and selection vector
170+
auto input_data = UnifiedVectorFormat::GetData<T>(input_data_u);
171+
auto input_data_sel = input_data_u.sel;
168172

169173
// local state and key
170174
auto &lstate = VCryptFunctionLocalState::ResetAndGet(state);
@@ -200,11 +204,13 @@ void EncryptVectorized(T *input_vector, uint64_t size, ExpressionState &state, V
200204
auto counter_vec_data = FlatVector::GetData<uint32_t>(*counter_vec);
201205
auto cipher_vec_data = FlatVector::GetData<uint16_t>(*cipher_vec);
202206

203-
// set nonce
207+
// ----------Set Nonce ----------
208+
204209
nonce_hi_data[0] = (static_cast<uint64_t>(lstate.iv[0]) << 32) | lstate.iv[1];
205210
nonce_lo_data[0] = lstate.iv[2];
206211

207-
// result vector is a dict vector containing encrypted data
212+
// ---------- Encrypt and store into a Dict Vector ----------
213+
208214
auto &blob = children[4];
209215
SelectionVector sel(size);
210216
blob->Slice(*blob, sel, size);
@@ -234,11 +240,24 @@ void EncryptVectorized(T *input_vector, uint64_t size, ExpressionState &state, V
234240
lstate.batch_size_in_bytes = batch_size * sizeof(T);
235241
uint64_t plaintext_bytes;
236242

237-
// copy the input vector to the buffer for padding
238-
memcpy(lstate.buffer_p, input_vector, size * sizeof(T));
243+
// We clear the buffer first to avoid leaking data
244+
memset(lstate.buffer_p, 0, total_size);
245+
246+
uint64_t offset = 0;
247+
for (uint32_t i = 0; i < size; i++){
248+
if (!validity.RowIsValid(input_data_sel->get_index(i))) {
249+
continue;
250+
}
251+
Store<T>(input_data[input_data_sel->get_index(i)], lstate.buffer_p + offset);
252+
offset += sizeof(T);
253+
}
254+
255+
// this only works for flat vectors
256+
// memcpy(lstate.buffer_p, input_data, size * sizeof(T));
239257
lstate.encryption_state->Process(lstate.buffer_p, total_size, lstate.buffer_p, total_size);
240258

241259
auto index = 0;
260+
uint64_t dict_index;
242261
auto batch_nr = 0;
243262
uint64_t buffer_offset;
244263

@@ -247,8 +266,9 @@ void EncryptVectorized(T *input_vector, uint64_t size, ExpressionState &state, V
247266

248267
// copy the first 8 bytes of plaintext of each batch
249268
// TODO: fix for edge case; resulting bytes are less then 64 bits (=8 bytes)
269+
// this is not trivial for dict vectors
250270
auto processed = batch_nr * BATCH_SIZE;
251-
memcpy(&plaintext_bytes, &input_vector[processed], sizeof(uint64_t));
271+
memcpy(&plaintext_bytes, &input_data[input_data_sel->get_index(processed)], sizeof(uint64_t));
252272

253273
blob_child_data[batch_nr] =
254274
StringVector::EmptyString(blob_child, lstate.batch_size_in_bytes);
@@ -258,9 +278,10 @@ void EncryptVectorized(T *input_vector, uint64_t size, ExpressionState &state, V
258278

259279
// set index in selection vector
260280
for (uint32_t j = 0; j < batch_size; j++) {
261-
if (!validity.RowIsValid(index)) {
262-
continue;
263-
}
281+
// if (!validity.RowIsValid(index)) {
282+
// fix validity later
283+
// continue;
284+
// }
264285
// set index of selection vector
265286
blob_sel.set_index(index, batch_nr);
266287
// cipher contains the (masked) position in the block
@@ -295,18 +316,19 @@ uint32_t RoundUpToBlockSize(uint32_t num) {
295316
}
296317

297318
template <typename T>
298-
void EncryptVectorizedVariable(T *input_vector, uint64_t size, ExpressionState &state, Vector &result, uint8_t vector_type) {
299-
300-
// we also need to store the total size of the encrypted blob...
301-
// maybe for strings its really impossible..
302-
// just do the first 64 bits for sz
319+
void EncryptVectorizedVariable(const UnifiedVectorFormat &input_data_u, uint64_t size, ExpressionState &state,
320+
Vector &result, uint8_t vector_type) {
303321

304322
// Storage Layout
305323
// ----------------------------------------------------------------------------
306324
// 8 bytes VCrypt version
307325
// BATCH_SIZE * 64 bytes is byte offset (could be truncated to 16 bits for small strings)
308326
// resulting bytes are total length of the encrypted data
309327

328+
// get the actual data and selection vector
329+
auto input_data = UnifiedVectorFormat::GetData<T>(input_data_u);
330+
auto input_data_sel = input_data_u.sel;
331+
310332
// local and global vcrypt state
311333
auto &lstate = VCryptFunctionLocalState::ResetAndGet(state);
312334
auto vcrypt_state = VCryptBasicFun::GetVCryptState(state);
@@ -374,6 +396,7 @@ void EncryptVectorizedVariable(T *input_vector, uint64_t size, ExpressionState &
374396
batch_size = size;
375397
}
376398

399+
uint64_t dict_index;
377400
for (uint32_t i = 0; i < batches; i++) {
378401
lstate.ResetIV<T>(counter_init);
379402

@@ -394,7 +417,12 @@ void EncryptVectorizedVariable(T *input_vector, uint64_t size, ExpressionState &
394417

395418
// loop through the batch to see if we have to reallocate the buffer
396419
for (uint32_t j = 0; j < batch_size; j++) {
397-
val_size = input_vector[index].GetSize();
420+
// if (!validity.RowIsValid(input_data_sel->get_index(i))) {
421+
// // Fix this later
422+
// continue;
423+
// }
424+
dict_index = input_data_sel->get_index(index);
425+
val_size = input_data[dict_index].GetSize();
398426
current_offset += val_size;
399427
Store<uint64_t>(current_offset, offset_buf_ptr);
400428
offset_buf_ptr += sizeof(uint64_t);
@@ -411,12 +439,13 @@ void EncryptVectorizedVariable(T *input_vector, uint64_t size, ExpressionState &
411439

412440
// loop again to store the actual values
413441
for (uint32_t j = 0; j < batch_size; j++) {
414-
val_size = input_vector[index].GetSize();
415-
memcpy(batch_ptr, input_vector[index].GetDataWriteable(), val_size);
442+
dict_index = input_data_sel->get_index(index);
443+
val_size = input_data[dict_index].GetSize();
444+
memcpy(batch_ptr, input_data[dict_index].GetDataWriteable(), val_size);
416445
batch_ptr += val_size;
417446
blob_sel.set_index(index, i);
418-
cipher_vec_data[index] = j;
419-
counter_vec_data[index] = counter_init;
447+
cipher_vec_data[dict_index] = j;
448+
counter_vec_data[dict_index] = counter_init;
420449
index++;
421450
}
422451

@@ -454,46 +483,38 @@ static void EncryptDataVectorized(DataChunk &args, ExpressionState &state,
454483
auto size = args.size();
455484

456485
auto &input_vector = args.data[0];
457-
UnifiedVectorFormat vdata_input;
458-
input_vector.ToUnifiedFormat(args.size(), vdata_input);
486+
UnifiedVectorFormat input_data_u;
487+
input_vector.ToUnifiedFormat(args.size(), input_data_u);
459488

460489
switch (vector_type) {
461490
case LogicalTypeId::TINYINT:
462491
case LogicalTypeId::UTINYINT:
463-
return EncryptVectorized<int8_t>((int8_t *)vdata_input.data,
464-
size, state, result, uint8_t(vector_type));
492+
return EncryptVectorized<int8_t>(input_data_u, size, state, result, uint8_t(vector_type));
465493
case LogicalTypeId::SMALLINT:
466494
case LogicalTypeId::USMALLINT:
467-
return EncryptVectorized<int16_t>((int16_t *)vdata_input.data,
468-
size, state, result, uint8_t(vector_type));
495+
return EncryptVectorized<int16_t>(input_data_u, size, state, result, uint8_t(vector_type));
469496
case LogicalTypeId::INTEGER:
470497
case LogicalTypeId::DATE:
471-
return EncryptVectorized<int32_t>((int32_t *)vdata_input.data,
498+
return EncryptVectorized<int32_t>(input_data_u,
472499
size, state, result, uint8_t(vector_type));
473500
case LogicalTypeId::UINTEGER:
474-
return EncryptVectorized<uint32_t>((uint32_t *)vdata_input.data,
475-
size, state, result, uint8_t(vector_type));
501+
return EncryptVectorized<uint32_t>(input_data_u, size, state, result, uint8_t(vector_type));
476502
case LogicalTypeId::BIGINT:
477503
case LogicalTypeId::TIMESTAMP:
478-
return EncryptVectorized<int64_t>((int64_t *)vdata_input.data,
479-
size, state, result, uint8_t(vector_type));
504+
return EncryptVectorized<int64_t>(input_data_u, size, state, result, uint8_t(vector_type));
480505
case LogicalTypeId::UBIGINT:
481-
return EncryptVectorized<uint64_t>((uint64_t *)vdata_input.data,
482-
size, state, result, uint8_t(vector_type));
506+
return EncryptVectorized<uint64_t>(input_data_u, size, state, result, uint8_t(vector_type));
483507
case LogicalTypeId::FLOAT:
484-
return EncryptVectorized<float>((float *)vdata_input.data,
485-
size, state, result, uint8_t(vector_type));
508+
return EncryptVectorized<float>(input_data_u, size, state, result, uint8_t(vector_type));
486509
case LogicalTypeId::DOUBLE:
487-
return EncryptVectorized<double>((double *)vdata_input.data,
488-
size, state, result, uint8_t(vector_type));
510+
return EncryptVectorized<double>(input_data_u, size, state, result, uint8_t(vector_type));
489511
case LogicalTypeId::VARCHAR:
490512
case LogicalTypeId::VARINT:
491513
case LogicalTypeId::CHAR:
492514
case LogicalTypeId::BLOB:
493515
case LogicalTypeId::MAP:
494516
case LogicalTypeId::LIST:
495-
return EncryptVectorizedVariable<string_t>((string_t *)vdata_input.data,
496-
size, state, result, uint8_t(vector_type));
517+
return EncryptVectorizedVariable<string_t>(input_data_u, size, state, result, uint8_t(vector_type));
497518
default:
498519
throw NotImplementedException("Unsupported type for Encryption");
499520
}

src/core/types.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,33 @@ LogicalType EncryptionTypes::GetEncryptionType(LogicalTypeId ltype) {
104104
}
105105
}
106106

107+
string EncryptionTypes::ToString(LogicalTypeId ltype) {
108+
switch (ltype) {
109+
case LogicalType::INTEGER:
110+
return "E_INTEGER";
111+
case LogicalType::UINTEGER:
112+
return "E_UINTEGER";
113+
case LogicalType::BIGINT:
114+
return "E_BIGINT";
115+
case LogicalType::UBIGINT:
116+
return "E_UBIGINT";
117+
case LogicalType::VARCHAR:
118+
return "E_VARCHAR";
119+
case LogicalTypeId::DATE:
120+
return "E_DATE";
121+
case LogicalTypeId::TIMESTAMP:
122+
return "E_TIMESTAMP";
123+
case LogicalTypeId::CHAR:
124+
return "E_CHAR";
125+
case LogicalTypeId::FLOAT:
126+
return "E_FLOAT";
127+
case LogicalTypeId::DOUBLE:
128+
return "E_DOUBLE";
129+
default:
130+
throw InternalException("LogicalType not convertible to Encrypted type");
131+
}
132+
}
133+
107134
// basic encrypted type
108135
// todo; we can just use one encrypted type, and just emplace the original type in the type modifiers...
109136
// the encrypted type just then needs an input (the original type)

src/include/vcrypt/core/types.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,12 @@ namespace core {
88

99
struct EncryptionTypes {
1010
static LogicalType E_INTEGER();
11-
static LogicalType EA_INTEGER();
1211
static LogicalType E_UINTEGER();
13-
static LogicalType EA_UINTEGER();
1412
static LogicalType E_BIGINT();
1513
static LogicalType E_UBIGINT();
1614
static LogicalType E_VARCHAR();
1715
static LogicalType E_DATE();
1816
static LogicalType E_TIMESTAMP();
19-
static LogicalType E_DECIMAL();
2017
static LogicalType E_FLOAT();
2118
static LogicalType E_DOUBLE();
2219
static LogicalType E_CHAR();
@@ -25,6 +22,7 @@ struct EncryptionTypes {
2522
static void Register(DatabaseInstance &db);
2623
static LogicalType GetBasicEncryptedType();
2724
static LogicalType GetEncryptionType(LogicalTypeId ltype);
25+
static string ToString(LogicalTypeId ltype);
2826
static vector<LogicalType> IsAvailable();
2927
static LogicalType GetOriginalType(EncryptedType etype);
3028
static EncryptedType GetEncryptedType(LogicalTypeId ltype);

0 commit comments

Comments
 (0)