@@ -302,29 +302,16 @@ namespace duckdb
302302
303303 struct HashAggregateState
304304 {
305- // No need to free this is just a reference to the algorithm.
306- blake3_hasher * blake3_hasher;
305+ bool is_touched;
306+ blake3_hasher blake3_hasher;
307307 EVP_MD_CTX *ctx;
308308
309309 HashAggregateState ()
310310 {
311- blake3_hasher = nullptr ;
311+ is_touched = false ;
312+ blake3_hasher_init (&blake3_hasher);
312313 ctx = nullptr ;
313314 }
314-
315- ~HashAggregateState ()
316- {
317- if (ctx)
318- {
319- EVP_MD_CTX_free (ctx);
320- ctx = nullptr ;
321- }
322- if (blake3_hasher)
323- {
324- delete blake3_hasher;
325- blake3_hasher = nullptr ;
326- }
327- }
328315 };
329316
330317 struct HashAggregateBindData : public FunctionData
@@ -354,8 +341,20 @@ namespace duckdb
354341 template <class STATE >
355342 static void Initialize (STATE &state)
356343 {
357- state. blake3_hasher = nullptr ;
344+ // So this may fill the state with random garbage
358345 state.ctx = nullptr ;
346+ state.is_touched = false ;
347+ blake3_hasher_init (&state.blake3_hasher );
348+ }
349+
350+ template <class STATE >
351+ static void Destroy (STATE &state, AggregateInputData &aggr_input_data)
352+ {
353+ if (state.ctx )
354+ {
355+ EVP_MD_CTX_free (state.ctx );
356+ state.ctx = nullptr ;
357+ }
359358 }
360359
361360 static bool IgnoreNull ()
@@ -371,56 +370,53 @@ namespace duckdb
371370
372371 if (bind_data.is_blake3 )
373372 {
374- if (!state.blake3_hasher )
373+ if (!state.is_touched )
375374 {
376- state.blake3_hasher = new blake3_hasher ;
377- blake3_hasher_init ( state.blake3_hasher ) ;
375+ blake3_hasher_init (& state.blake3_hasher ) ;
376+ state.is_touched = true ;
378377 }
379378
380379 // hash the record length as well to prevent length extension attacks
381380 if constexpr (std::is_same_v<A_TYPE, string_t >)
382381 {
383382 const uint64_t size = a_data.GetSize ();
384- blake3_hasher_update (state.blake3_hasher , &size, sizeof (uint64_t ));
385- blake3_hasher_update (state.blake3_hasher , a_data.GetDataUnsafe (), size);
383+ blake3_hasher_update (& state.blake3_hasher , &size, sizeof (uint64_t ));
384+ blake3_hasher_update (& state.blake3_hasher , a_data.GetDataUnsafe (), size);
386385 }
387386 else
388387 {
389- blake3_hasher_update (state.blake3_hasher , &a_data, sizeof (a_data));
388+ blake3_hasher_update (& state.blake3_hasher , &a_data, sizeof (a_data));
390389 }
391390 }
392391 else
393392 {
394- if (!state.ctx )
393+ if (!state.is_touched )
395394 {
396395 state.ctx = EVP_MD_CTX_new ();
397396 if (EVP_DigestInit_ex (state.ctx , bind_data.md , nullptr ) != 1 )
398397 {
399- EVP_MD_CTX_free (state.ctx );
400398 throw InternalException (" Failed to initialize hash context" );
401399 }
400+ state.is_touched = true ;
402401 }
403402
404403 if constexpr (std::is_same_v<A_TYPE, string_t >)
405404 {
406405 const uint64_t size = a_data.GetSize ();
407406 if (EVP_DigestUpdate (state.ctx , &size, sizeof (uint64_t )) != 1 )
408407 {
409- EVP_MD_CTX_free (state.ctx );
410408 throw InternalException (" Failed to update hash" );
411409 }
412410
413411 if (EVP_DigestUpdate (state.ctx , a_data.GetDataUnsafe (), size) != 1 )
414412 {
415- EVP_MD_CTX_free (state.ctx );
416413 throw InternalException (" Failed to update hash" );
417414 }
418415 }
419416 else
420417 {
421418 if (EVP_DigestUpdate (state.ctx , &a_data, sizeof (a_data)) != 1 )
422419 {
423- EVP_MD_CTX_free (state.ctx );
424420 throw InternalException (" Failed to update hash" );
425421 }
426422 }
@@ -441,13 +437,21 @@ namespace duckdb
441437 static void Combine (const STATE &source, STATE &target, AggregateInputData &aggr_input_data)
442438 {
443439 auto &bind_data = aggr_input_data.bind_data ->Cast <HashAggregateBindData>();
444- const bool source_used = bind_data. is_blake3 ? ( source.blake3_hasher != nullptr ) : (source. ctx != nullptr ) ;
445- const bool target_used = bind_data. is_blake3 ? ( target.blake3_hasher != nullptr ) : (target. ctx != nullptr ) ;
440+ const bool source_used = source.is_touched ;
441+ const bool target_used = target.is_touched ;
446442
447443 if (source_used && !target_used)
448444 {
449- target.blake3_hasher = source.blake3_hasher ;
450- target.ctx = source.ctx ;
445+ if (bind_data.is_blake3 )
446+ {
447+ target.blake3_hasher = source.blake3_hasher ;
448+ }
449+ else
450+ {
451+ target.ctx = EVP_MD_CTX_new ();
452+ EVP_MD_CTX_copy_ex (target.ctx , source.ctx );
453+ }
454+ target.is_touched = true ;
451455 }
452456 else if (!source_used)
453457 {
@@ -464,19 +468,15 @@ namespace duckdb
464468 auto &bind_data = finalize_data.input .bind_data ->Cast <HashAggregateBindData>();
465469 if (bind_data.is_blake3 )
466470 {
467- if (!state.blake3_hasher )
471+ if (!state.is_touched )
468472 {
469- delete state.blake3_hasher ;
470- state.blake3_hasher = nullptr ;
471473 finalize_data.ReturnNull ();
472474 return ;
473475 }
474476 char output[BLAKE3_OUT_LEN];
475- blake3_hasher_finalize (state.blake3_hasher , reinterpret_cast <uint8_t *>(&output), BLAKE3_OUT_LEN);
477+ blake3_hasher_finalize (& state.blake3_hasher , reinterpret_cast <uint8_t *>(&output), BLAKE3_OUT_LEN);
476478 target = StringVector::AddStringOrBlob (finalize_data.result , reinterpret_cast <const char *>(&output),
477479 BLAKE3_OUT_LEN);
478- delete state.blake3_hasher ;
479- state.blake3_hasher = nullptr ;
480480 }
481481 else
482482 {
@@ -491,11 +491,8 @@ namespace duckdb
491491 // Finalize the hash
492492 if (EVP_DigestFinal_ex (state.ctx , hash_result, &hash_len) != 1 )
493493 {
494- EVP_MD_CTX_free (state.ctx );
495494 throw InternalException (" Failed to finalize hash" );
496495 }
497- EVP_MD_CTX_free (state.ctx );
498- state.ctx = nullptr ;
499496 target = StringVector::AddStringOrBlob (finalize_data.result , reinterpret_cast <const char *>(&hash_result),
500497 hash_len);
501498 }
@@ -549,7 +546,7 @@ namespace duckdb
549546 static void RegisterHashAggType (AggregateFunctionSet &agg_set, const LogicalType &logical_type)
550547 {
551548 auto agg_func =
552- AggregateFunction::UnaryAggregate <HashAggregateState, CPP_TYPE, string_t , HashAggregateOperation<HashAggregateState>>(
549+ AggregateFunction::UnaryAggregateDestructor <HashAggregateState, CPP_TYPE, string_t , HashAggregateOperation<HashAggregateState>>(
553550 logical_type, LogicalType::BLOB);
554551 agg_func.order_dependent = AggregateOrderDependent::ORDER_DEPENDENT;
555552 agg_func.distinct_dependent = AggregateDistinctDependent::DISTINCT_DEPENDENT;
0 commit comments