Skip to content

Commit cfc17e4

Browse files
committed
fix: some memory cleanups
1 parent 58e38a3 commit cfc17e4

File tree

1 file changed

+40
-43
lines changed

1 file changed

+40
-43
lines changed

src/crypto_extension.cpp

Lines changed: 40 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -302,29 +302,16 @@ namespace duckdb
302302

303303
struct HashAggregateState
304304
{
305-
// No need to free this is just a reference to the algorithm.
306-
blake3_hasher *blake3_hasher;
305+
bool is_touched;
306+
blake3_hasher blake3_hasher;
307307
EVP_MD_CTX *ctx;
308308

309309
HashAggregateState()
310310
{
311-
blake3_hasher = nullptr;
311+
is_touched = false;
312+
blake3_hasher_init(&blake3_hasher);
312313
ctx = nullptr;
313314
}
314-
315-
~HashAggregateState()
316-
{
317-
if (ctx)
318-
{
319-
EVP_MD_CTX_free(ctx);
320-
ctx = nullptr;
321-
}
322-
if (blake3_hasher)
323-
{
324-
delete blake3_hasher;
325-
blake3_hasher = nullptr;
326-
}
327-
}
328315
};
329316

330317
struct HashAggregateBindData : public FunctionData
@@ -354,8 +341,20 @@ namespace duckdb
354341
template <class STATE>
355342
static void Initialize(STATE &state)
356343
{
357-
state.blake3_hasher = nullptr;
344+
// So this may fill the state with random garbage
358345
state.ctx = nullptr;
346+
state.is_touched = false;
347+
blake3_hasher_init(&state.blake3_hasher);
348+
}
349+
350+
template <class STATE>
351+
static void Destroy(STATE &state, AggregateInputData &aggr_input_data)
352+
{
353+
if (state.ctx)
354+
{
355+
EVP_MD_CTX_free(state.ctx);
356+
state.ctx = nullptr;
357+
}
359358
}
360359

361360
static bool IgnoreNull()
@@ -371,56 +370,53 @@ namespace duckdb
371370

372371
if (bind_data.is_blake3)
373372
{
374-
if (!state.blake3_hasher)
373+
if (!state.is_touched)
375374
{
376-
state.blake3_hasher = new blake3_hasher;
377-
blake3_hasher_init(state.blake3_hasher);
375+
blake3_hasher_init(&state.blake3_hasher);
376+
state.is_touched = true;
378377
}
379378

380379
// hash the record length as well to prevent length extension attacks
381380
if constexpr (std::is_same_v<A_TYPE, string_t>)
382381
{
383382
const uint64_t size = a_data.GetSize();
384-
blake3_hasher_update(state.blake3_hasher, &size, sizeof(uint64_t));
385-
blake3_hasher_update(state.blake3_hasher, a_data.GetDataUnsafe(), size);
383+
blake3_hasher_update(&state.blake3_hasher, &size, sizeof(uint64_t));
384+
blake3_hasher_update(&state.blake3_hasher, a_data.GetDataUnsafe(), size);
386385
}
387386
else
388387
{
389-
blake3_hasher_update(state.blake3_hasher, &a_data, sizeof(a_data));
388+
blake3_hasher_update(&state.blake3_hasher, &a_data, sizeof(a_data));
390389
}
391390
}
392391
else
393392
{
394-
if (!state.ctx)
393+
if (!state.is_touched)
395394
{
396395
state.ctx = EVP_MD_CTX_new();
397396
if (EVP_DigestInit_ex(state.ctx, bind_data.md, nullptr) != 1)
398397
{
399-
EVP_MD_CTX_free(state.ctx);
400398
throw InternalException("Failed to initialize hash context");
401399
}
400+
state.is_touched = true;
402401
}
403402

404403
if constexpr (std::is_same_v<A_TYPE, string_t>)
405404
{
406405
const uint64_t size = a_data.GetSize();
407406
if (EVP_DigestUpdate(state.ctx, &size, sizeof(uint64_t)) != 1)
408407
{
409-
EVP_MD_CTX_free(state.ctx);
410408
throw InternalException("Failed to update hash");
411409
}
412410

413411
if (EVP_DigestUpdate(state.ctx, a_data.GetDataUnsafe(), size) != 1)
414412
{
415-
EVP_MD_CTX_free(state.ctx);
416413
throw InternalException("Failed to update hash");
417414
}
418415
}
419416
else
420417
{
421418
if (EVP_DigestUpdate(state.ctx, &a_data, sizeof(a_data)) != 1)
422419
{
423-
EVP_MD_CTX_free(state.ctx);
424420
throw InternalException("Failed to update hash");
425421
}
426422
}
@@ -441,13 +437,21 @@ namespace duckdb
441437
static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data)
442438
{
443439
auto &bind_data = aggr_input_data.bind_data->Cast<HashAggregateBindData>();
444-
const bool source_used = bind_data.is_blake3 ? (source.blake3_hasher != nullptr) : (source.ctx != nullptr);
445-
const bool target_used = bind_data.is_blake3 ? (target.blake3_hasher != nullptr) : (target.ctx != nullptr);
440+
const bool source_used = source.is_touched;
441+
const bool target_used = target.is_touched;
446442

447443
if (source_used && !target_used)
448444
{
449-
target.blake3_hasher = source.blake3_hasher;
450-
target.ctx = source.ctx;
445+
if (bind_data.is_blake3)
446+
{
447+
target.blake3_hasher = source.blake3_hasher;
448+
}
449+
else
450+
{
451+
target.ctx = EVP_MD_CTX_new();
452+
EVP_MD_CTX_copy_ex(target.ctx, source.ctx);
453+
}
454+
target.is_touched = true;
451455
}
452456
else if (!source_used)
453457
{
@@ -464,19 +468,15 @@ namespace duckdb
464468
auto &bind_data = finalize_data.input.bind_data->Cast<HashAggregateBindData>();
465469
if (bind_data.is_blake3)
466470
{
467-
if (!state.blake3_hasher)
471+
if (!state.is_touched)
468472
{
469-
delete state.blake3_hasher;
470-
state.blake3_hasher = nullptr;
471473
finalize_data.ReturnNull();
472474
return;
473475
}
474476
char output[BLAKE3_OUT_LEN];
475-
blake3_hasher_finalize(state.blake3_hasher, reinterpret_cast<uint8_t *>(&output), BLAKE3_OUT_LEN);
477+
blake3_hasher_finalize(&state.blake3_hasher, reinterpret_cast<uint8_t *>(&output), BLAKE3_OUT_LEN);
476478
target = StringVector::AddStringOrBlob(finalize_data.result, reinterpret_cast<const char *>(&output),
477479
BLAKE3_OUT_LEN);
478-
delete state.blake3_hasher;
479-
state.blake3_hasher = nullptr;
480480
}
481481
else
482482
{
@@ -491,11 +491,8 @@ namespace duckdb
491491
// Finalize the hash
492492
if (EVP_DigestFinal_ex(state.ctx, hash_result, &hash_len) != 1)
493493
{
494-
EVP_MD_CTX_free(state.ctx);
495494
throw InternalException("Failed to finalize hash");
496495
}
497-
EVP_MD_CTX_free(state.ctx);
498-
state.ctx = nullptr;
499496
target = StringVector::AddStringOrBlob(finalize_data.result, reinterpret_cast<const char *>(&hash_result),
500497
hash_len);
501498
}
@@ -549,7 +546,7 @@ namespace duckdb
549546
static void RegisterHashAggType(AggregateFunctionSet &agg_set, const LogicalType &logical_type)
550547
{
551548
auto agg_func =
552-
AggregateFunction::UnaryAggregate<HashAggregateState, CPP_TYPE, string_t, HashAggregateOperation<HashAggregateState>>(
549+
AggregateFunction::UnaryAggregateDestructor<HashAggregateState, CPP_TYPE, string_t, HashAggregateOperation<HashAggregateState>>(
553550
logical_type, LogicalType::BLOB);
554551
agg_func.order_dependent = AggregateOrderDependent::ORDER_DEPENDENT;
555552
agg_func.distinct_dependent = AggregateDistinctDependent::DISTINCT_DEPENDENT;

0 commit comments

Comments
 (0)