Skip to content

Commit 76919c6

Browse files
authored
Merge pull request #2606 from ben-schwen/hashtree_trailing_storage
Hashtree uses trailing storage instead of array member hack
2 parents 755a8e0 + fe7df2d commit 76919c6

File tree

1 file changed

+71
-53
lines changed

1 file changed

+71
-53
lines changed

highs/util/HighsHashTree.h

Lines changed: 71 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -391,14 +391,32 @@ class HighsHashTree {
391391

392392
struct BranchNode {
393393
Occupation occupation;
394-
NodePtr child[1];
394+
395+
NodePtr* children() {
396+
return reinterpret_cast<NodePtr*>(this + 1);
397+
}
398+
399+
const NodePtr* children() const {
400+
return reinterpret_cast<const NodePtr*>(this + 1);
401+
}
402+
403+
NodePtr& child(int index) { return children()[index]; }
404+
405+
const NodePtr& child(int index) const { return children()[index]; }
406+
407+
NodePtr* childPtr(int index) { return children() + index; }
408+
409+
const NodePtr* childPtr(int index) const { return children() + index; }
395410
};
396411

412+
static_assert(sizeof(BranchNode) % alignof(NodePtr) == 0,
413+
"BranchNode trailing storage must stay NodePtr aligned");
414+
397415
// allocate branch nodes in multiples of 64 bytes to reduce allocator stress
398416
// with different sizes and reduce reallocations of nodes
399417
static constexpr size_t getBranchNodeSize(int numChilds) {
400-
return (sizeof(BranchNode) + size_t(numChilds - 1) * sizeof(NodePtr) + 63) &
401-
~63;
418+
return (sizeof(BranchNode) + size_t(numChilds) * sizeof(NodePtr) + 63) &
419+
~size_t(63);
402420
};
403421

404422
static BranchNode* createBranchingNode(int numChilds) {
@@ -421,20 +439,19 @@ class HighsHashTree {
421439
size_t rightSize = rightChilds * sizeof(NodePtr);
422440

423441
if (newSize == getBranchNodeSize(location + rightChilds)) {
424-
memmove(&branch->child[location + 1], &branch->child[location],
442+
memmove(branch->childPtr(location + 1), branch->childPtr(location),
425443
rightSize);
426444

427445
return branch;
428446
}
429447

430448
BranchNode* newBranch = (BranchNode*)::operator new(newSize);
431-
// sizeof(Branch) already contains the size for 1 pointer. So we just
432-
// need to add the left and right sizes up for the number of
433-
// additional pointers
434-
size_t leftSize = sizeof(BranchNode) + (location - 1) * sizeof(NodePtr);
449+
// copy the header plus the pointers left of the insertion index into the
450+
// new storage
451+
size_t leftSize = sizeof(BranchNode) + size_t(location) * sizeof(NodePtr);
435452

436453
memcpy(newBranch, branch, leftSize);
437-
memcpy(&newBranch->child[location + 1], &branch->child[location],
454+
memcpy(newBranch->childPtr(location + 1), branch->childPtr(location),
438455
rightSize);
439456

440457
destroyBranchingNode(branch);
@@ -573,7 +590,7 @@ class HighsHashTree {
573590
branch->occupation.num_set_until(static_cast<uint8_t>(pos)) - 1;
574591

575592
do {
576-
if (find_recurse(branch->child[j],
593+
if (find_recurse(branch->child(j),
577594
compute_hash(leaf->entries[i].key()), hashPos + 1,
578595
leaf->entries[i].key()))
579596
return &leaf->entries[i];
@@ -603,7 +620,7 @@ class HighsHashTree {
603620
// threshold
604621
int childEntries = 0;
605622
for (int i = 0; i <= newNumChild; ++i) {
606-
childEntries += branch->child[i].numEntriesEstimate();
623+
childEntries += branch->child(i).numEntriesEstimate();
607624
if (childEntries > kLeafBurstThreshold) break;
608625
}
609626

@@ -617,7 +634,7 @@ class HighsHashTree {
617634
// accesses of nodes that are not in cache.
618635
childEntries = 0;
619636
for (int i = 0; i <= newNumChild; ++i)
620-
childEntries += branch->child[i].numEntries();
637+
childEntries += branch->child(i).numEntries();
621638

622639
// check again if we exceed due to the extremely unlikely case
623640
// of having less than 5 list nodes with together more than 30 entries
@@ -628,28 +645,28 @@ class HighsHashTree {
628645
InnerLeaf<1>* newLeafSize1 = new InnerLeaf<1>;
629646
newNode = newLeafSize1;
630647
for (int i = 0; i <= newNumChild; ++i)
631-
mergeIntoLeaf(newLeafSize1, hashPos, branch->child[i]);
648+
mergeIntoLeaf(newLeafSize1, hashPos, branch->child(i));
632649
break;
633650
}
634651
case 2: {
635652
InnerLeaf<2>* newLeafSize2 = new InnerLeaf<2>;
636653
newNode = newLeafSize2;
637654
for (int i = 0; i <= newNumChild; ++i)
638-
mergeIntoLeaf(newLeafSize2, hashPos, branch->child[i]);
655+
mergeIntoLeaf(newLeafSize2, hashPos, branch->child(i));
639656
break;
640657
}
641658
case 3: {
642659
InnerLeaf<3>* newLeafSize3 = new InnerLeaf<3>;
643660
newNode = newLeafSize3;
644661
for (int i = 0; i <= newNumChild; ++i)
645-
mergeIntoLeaf(newLeafSize3, hashPos, branch->child[i]);
662+
mergeIntoLeaf(newLeafSize3, hashPos, branch->child(i));
646663
break;
647664
}
648665
case 4: {
649666
InnerLeaf<4>* newLeafSize4 = new InnerLeaf<4>;
650667
newNode = newLeafSize4;
651668
for (int i = 0; i <= newNumChild; ++i)
652-
mergeIntoLeaf(newLeafSize4, hashPos, branch->child[i]);
669+
mergeIntoLeaf(newLeafSize4, hashPos, branch->child(i));
653670
break;
654671
}
655672
default:
@@ -667,7 +684,7 @@ class HighsHashTree {
667684
size_t rightSize = (newNumChild - location) * sizeof(NodePtr);
668685
if (newSize == getBranchNodeSize(newNumChild + 1)) {
669686
// allocated size class is the same, so we do not allocate a new node
670-
memmove(&branch->child[location], &branch->child[location + 1],
687+
memmove(branch->childPtr(location), branch->childPtr(location + 1),
671688
rightSize);
672689
newNode = branch;
673690
} else {
@@ -676,10 +693,10 @@ class HighsHashTree {
676693
newNode = compressedBranch;
677694

678695
size_t leftSize =
679-
offsetof(BranchNode, child) + location * sizeof(NodePtr);
696+
sizeof(BranchNode) + size_t(location) * sizeof(NodePtr);
680697
memcpy(compressedBranch, branch, leftSize);
681-
memcpy(&compressedBranch->child[location], &branch->child[location + 1],
682-
rightSize);
698+
memcpy(compressedBranch->childPtr(location),
699+
branch->childPtr(location + 1), rightSize);
683700

684701
destroyBranchingNode(branch);
685702
}
@@ -775,16 +792,16 @@ class HighsHashTree {
775792
branch->occupation = occupation;
776793

777794
if (hashPos + 1 == kMaxDepth) {
778-
for (int i = 0; i < branchSize; ++i) branch->child[i] = nullptr;
795+
for (int i = 0; i < branchSize; ++i) branch->child(i) = nullptr;
779796

780797
for (int i = 0; i < leaf->size; ++i) {
781798
int pos =
782799
occupation.num_set_until(get_first_chunk16(leaf->hashes[i])) -
783800
1;
784-
if (branch->child[pos].getType() == kEmpty)
785-
branch->child[pos] = new ListLeaf(std::move(leaf->entries[i]));
801+
if (branch->child(pos).getType() == kEmpty)
802+
branch->child(pos) = new ListLeaf(std::move(leaf->entries[i]));
786803
else {
787-
ListLeaf* listLeaf = branch->child[pos].getListLeaf();
804+
ListLeaf* listLeaf = branch->child(pos).getListLeaf();
788805
ListNode* newNode = new ListNode(std::move(listLeaf->first));
789806
listLeaf->first.next = newNode;
790807
listLeaf->first.entry = std::move(leaf->entries[i]);
@@ -797,11 +814,11 @@ class HighsHashTree {
797814
ListLeaf* listLeaf;
798815

799816
int pos = occupation.num_set_until(get_hash_chunk(hash, hashPos)) - 1;
800-
if (branch->child[pos].getType() == kEmpty) {
817+
if (branch->child(pos).getType() == kEmpty) {
801818
listLeaf = new ListLeaf(std::move(entry));
802-
branch->child[pos] = listLeaf;
819+
branch->child(pos) = listLeaf;
803820
} else {
804-
listLeaf = branch->child[pos].getListLeaf();
821+
listLeaf = branch->child(pos).getListLeaf();
805822
ListNode* newNode = new ListNode(std::move(listLeaf->first));
806823
listLeaf->first.next = newNode;
807824
listLeaf->first.entry = std::move(entry);
@@ -821,13 +838,13 @@ class HighsHashTree {
821838
if (maxEntriesPerLeaf <= InnerLeaf<1>::capacity()) {
822839
// all items can go into the smallest leaf size
823840
for (int i = 0; i < branchSize; ++i)
824-
branch->child[i] = new InnerLeaf<1>;
841+
branch->child(i) = new InnerLeaf<1>;
825842

826843
for (int i = 0; i < leaf->size; ++i) {
827844
int pos =
828845
occupation.num_set_until(get_first_chunk16(leaf->hashes[i])) -
829846
1;
830-
branch->child[pos].getInnerLeafSizeClass1()->insert_entry(
847+
branch->child(pos).getInnerLeafSizeClass1()->insert_entry(
831848
compute_hash(leaf->entries[i].key()), hashPos + 1,
832849
leaf->entries[i]);
833850
}
@@ -836,7 +853,7 @@ class HighsHashTree {
836853

837854
int pos =
838855
occupation.num_set_until(get_hash_chunk(hash, hashPos)) - 1;
839-
return branch->child[pos].getInnerLeafSizeClass1()->insert_entry(
856+
return branch->child(pos).getInnerLeafSizeClass1()->insert_entry(
840857
hash, hashPos + 1, entry);
841858
} else {
842859
// there are many collisions, determine the exact sizes first
@@ -852,16 +869,16 @@ class HighsHashTree {
852869
for (int i = 0; i < branchSize; ++i) {
853870
switch (entries_to_size_class(sizes[i])) {
854871
case 1:
855-
branch->child[i] = new InnerLeaf<1>;
872+
branch->child(i) = new InnerLeaf<1>;
856873
break;
857874
case 2:
858-
branch->child[i] = new InnerLeaf<2>;
875+
branch->child(i) = new InnerLeaf<2>;
859876
break;
860877
case 3:
861-
branch->child[i] = new InnerLeaf<3>;
878+
branch->child(i) = new InnerLeaf<3>;
862879
break;
863880
case 4:
864-
branch->child[i] = new InnerLeaf<4>;
881+
branch->child(i) = new InnerLeaf<4>;
865882
break;
866883
default:
867884
// Unexpected result from 'entries_to_size_class'
@@ -874,24 +891,24 @@ class HighsHashTree {
874891
occupation.num_set_until(get_first_chunk16(leaf->hashes[i])) -
875892
1;
876893

877-
switch (branch->child[pos].getType()) {
894+
switch (branch->child(pos).getType()) {
878895
case kInnerLeafSizeClass1:
879-
branch->child[pos].getInnerLeafSizeClass1()->insert_entry(
896+
branch->child(pos).getInnerLeafSizeClass1()->insert_entry(
880897
compute_hash(leaf->entries[i].key()), hashPos + 1,
881898
leaf->entries[i]);
882899
break;
883900
case kInnerLeafSizeClass2:
884-
branch->child[pos].getInnerLeafSizeClass2()->insert_entry(
901+
branch->child(pos).getInnerLeafSizeClass2()->insert_entry(
885902
compute_hash(leaf->entries[i].key()), hashPos + 1,
886903
leaf->entries[i]);
887904
break;
888905
case kInnerLeafSizeClass3:
889-
branch->child[pos].getInnerLeafSizeClass3()->insert_entry(
906+
branch->child(pos).getInnerLeafSizeClass3()->insert_entry(
890907
compute_hash(leaf->entries[i].key()), hashPos + 1,
891908
leaf->entries[i]);
892909
break;
893910
case kInnerLeafSizeClass4:
894-
branch->child[pos].getInnerLeafSizeClass4()->insert_entry(
911+
branch->child(pos).getInnerLeafSizeClass4()->insert_entry(
895912
compute_hash(leaf->entries[i].key()), hashPos + 1,
896913
leaf->entries[i]);
897914
break;
@@ -904,14 +921,14 @@ class HighsHashTree {
904921
delete leaf;
905922

906923
int pos = occupation.num_set_until(hashChunk) - 1;
907-
insertNode = &branch->child[pos];
924+
insertNode = branch->childPtr(pos);
908925
++hashPos;
909926
} else {
910927
// extremely unlikely that the new branch node only gets one
911928
// child in that case create it and defer the insertion into
912929
// the next depth
913-
branch->child[0] = leaf;
914-
insertNode = &branch->child[0];
930+
branch->child(0) = leaf;
931+
insertNode = branch->childPtr(0);
915932
++hashPos;
916933
leaf->rehash(hashPos);
917934
}
@@ -930,12 +947,12 @@ class HighsHashTree {
930947
branch = addChildToBranchNode(branch, get_hash_chunk(hash, hashPos),
931948
location);
932949

933-
branch->child[location] = nullptr;
950+
branch->child(location) = nullptr;
934951
branch->occupation.set(get_hash_chunk(hash, hashPos));
935952
}
936953

937954
*insertNode = branch;
938-
insertNode = &branch->child[location];
955+
insertNode = branch->childPtr(location);
939956
++hashPos;
940957
}
941958
}
@@ -1038,9 +1055,9 @@ class HighsHashTree {
10381055

10391056
int location =
10401057
branch->occupation.num_set_until(get_hash_chunk(hash, hashPos)) - 1;
1041-
erase_recurse(&branch->child[location], hash, hashPos + 1, key);
1058+
erase_recurse(branch->childPtr(location), hash, hashPos + 1, key);
10421059

1043-
if (branch->child[location].getType() != kEmpty) return;
1060+
if (branch->child(location).getType() != kEmpty) return;
10441061

10451062
branch->occupation.flip(get_hash_chunk(hash, hashPos));
10461063

@@ -1088,7 +1105,7 @@ class HighsHashTree {
10881105
return nullptr;
10891106
int location =
10901107
branch->occupation.num_set_until(get_hash_chunk(hash, hashPos)) - 1;
1091-
node = branch->child[location];
1108+
node = branch->child(location);
10921109
++hashPos;
10931110
}
10941111
}
@@ -1147,8 +1164,8 @@ class HighsHashTree {
11471164
branch2->occupation.num_set_until(static_cast<uint8_t>(pos)) - 1;
11481165

11491166
const HighsHashTableEntry<K, V>* match =
1150-
find_common_recurse(branch1->child[location1],
1151-
branch2->child[location2], hashPos + 1);
1167+
find_common_recurse(branch1->child(location1),
1168+
branch2->child(location2), hashPos + 1);
11521169
if (match != nullptr) return match;
11531170
}
11541171

@@ -1191,7 +1208,7 @@ class HighsHashTree {
11911208
BranchNode* branch = node.getBranchNode();
11921209
int size = branch->occupation.num_set();
11931210

1194-
for (int i = 0; i < size; ++i) destroy_recurse(branch->child[i]);
1211+
for (int i = 0; i < size; ++i) destroy_recurse(branch->child(i));
11951212

11961213
destroyBranchingNode(branch);
11971214
}
@@ -1240,7 +1257,7 @@ class HighsHashTree {
12401257
(BranchNode*)::operator new(getBranchNodeSize(size));
12411258
newBranch->occupation = branch->occupation;
12421259
for (int i = 0; i < size; ++i)
1243-
newBranch->child[i] = copy_recurse(branch->child[i]);
1260+
newBranch->child(i) = copy_recurse(branch->child(i));
12441261

12451262
return newBranch;
12461263
}
@@ -1292,7 +1309,8 @@ class HighsHashTree {
12921309
BranchNode* branch = node.getBranchNode();
12931310
int size = branch->occupation.num_set();
12941311

1295-
for (int i = 0; i < size; ++i) for_each_recurse<R>(branch->child[i], f);
1312+
for (int i = 0; i < size; ++i)
1313+
for_each_recurse<R>(branch->child(i), f);
12961314
}
12971315
}
12981316
}
@@ -1354,7 +1372,7 @@ class HighsHashTree {
13541372
int size = branch->occupation.num_set();
13551373

13561374
for (int i = 0; i < size; ++i) {
1357-
auto x = for_each_recurse<R>(branch->child[i], f);
1375+
auto x = for_each_recurse<R>(branch->child(i), f);
13581376
if (x) return x;
13591377
}
13601378
}

0 commit comments

Comments
 (0)