@@ -460,38 +460,72 @@ internal class TrieNode<K, V>(
460
460
}
461
461
}
462
462
463
- private fun mutablePutAllFromOtherNodeCell (other : TrieNode <K , V >,
464
- positionMask : Int ,
465
- shift : Int ,
466
- intersectionCounter : DeltaCounter ,
467
- mutator : PersistentHashMapBuilder <K , V >): TrieNode <K , V > {
468
- return when {
469
- other.hasNodeAt(positionMask) -> {
470
- mutablePutAll(
471
- other.nodeAtIndex(other.nodeIndex(positionMask)),
472
- shift + LOG_MAX_BRANCHING_FACTOR ,
473
- intersectionCounter,
474
- mutator
475
- )
463
+ /* *
464
+ * Updates the cell of this node at [positionMask] with entries from the cell of [otherNode] at [positionMask].
465
+ */
466
+ private fun mutablePutAllFromOtherNodeCell (
467
+ otherNode : TrieNode <K , V >,
468
+ positionMask : Int ,
469
+ shift : Int ,
470
+ intersectionCounter : DeltaCounter ,
471
+ mutator : PersistentHashMapBuilder <K , V >
472
+ ): TrieNode <K , V > = when {
473
+ this .hasNodeAt(positionMask) -> {
474
+ val targetNode = this .nodeAtIndex(nodeIndex(positionMask))
475
+ when {
476
+ otherNode.hasNodeAt(positionMask) -> {
477
+ val otherTargetNode = otherNode.nodeAtIndex(otherNode.nodeIndex(positionMask))
478
+ targetNode.mutablePutAll(otherTargetNode, shift + LOG_MAX_BRANCHING_FACTOR , intersectionCounter, mutator)
479
+ }
480
+ otherNode.hasEntryAt(positionMask) -> {
481
+ val keyIndex = otherNode.entryKeyIndex(positionMask)
482
+ val key = otherNode.keyAtIndex(keyIndex)
483
+ val value = otherNode.valueAtKeyIndex(keyIndex)
484
+ val oldSize = mutator.size
485
+ targetNode.mutablePut(key.hashCode(), key, value, shift + LOG_MAX_BRANCHING_FACTOR , mutator).also {
486
+ if (mutator.size == oldSize) intersectionCounter.count++
487
+ }
488
+ }
489
+ else -> targetNode
476
490
}
477
- other.hasEntryAt(positionMask) -> {
478
- val keyIndex = other.entryKeyIndex(positionMask)
479
- val key = other.keyAtIndex(keyIndex)
480
- val value = other.valueAtKeyIndex(keyIndex)
481
- val oldSize = mutator.size
482
- val newNode = mutablePut(
483
- key.hashCode(),
484
- key,
485
- value,
486
- shift + LOG_MAX_BRANCHING_FACTOR ,
487
- mutator
488
- )
489
- if (mutator.size == oldSize) {
490
- intersectionCounter.count++
491
+ }
492
+
493
+ otherNode.hasNodeAt(positionMask) -> {
494
+ val otherTargetNode = otherNode.nodeAtIndex(otherNode.nodeIndex(positionMask))
495
+ when {
496
+ this .hasEntryAt(positionMask) -> {
497
+ // if otherTargetNode already has a value associated with the key, do not put this entry
498
+ val keyIndex = this .entryKeyIndex(positionMask)
499
+ val key = this .keyAtIndex(keyIndex)
500
+ if (otherTargetNode.containsKey(key.hashCode(), key, shift + LOG_MAX_BRANCHING_FACTOR )) {
501
+ intersectionCounter.count++
502
+ otherTargetNode
503
+ } else {
504
+ val value = this .valueAtKeyIndex(keyIndex)
505
+ otherTargetNode.mutablePut(key.hashCode(), key, value, shift + LOG_MAX_BRANCHING_FACTOR , mutator)
506
+ }
491
507
}
492
- newNode
508
+ else -> otherTargetNode
493
509
}
494
- else -> this
510
+ }
511
+
512
+ else -> { // two entries, and they are not equal by key. See (**) in mutablePutAll
513
+ val thisKeyIndex = this .entryKeyIndex(positionMask)
514
+ val thisKey = this .keyAtIndex(thisKeyIndex)
515
+ val thisValue = this .valueAtKeyIndex(thisKeyIndex)
516
+ val otherKeyIndex = otherNode.entryKeyIndex(positionMask)
517
+ val otherKey = otherNode.keyAtIndex(otherKeyIndex)
518
+ val otherValue = otherNode.valueAtKeyIndex(otherKeyIndex)
519
+ makeNode(
520
+ thisKey.hashCode(),
521
+ thisKey,
522
+ thisValue,
523
+ otherKey.hashCode(),
524
+ otherKey,
525
+ otherValue,
526
+ shift + LOG_MAX_BRANCHING_FACTOR ,
527
+ mutator.ownership
528
+ )
495
529
}
496
530
}
497
531
@@ -575,7 +609,7 @@ internal class TrieNode<K, V>(
575
609
// but not in the new data nodes
576
610
var newDataMap = dataMap xor otherNode.dataMap and newNodeMap.inv ()
577
611
// (**) now, this is tricky: we have a number of entry-entry pairs and we don't know yet whether
578
- // they result in an entry (if they are equal) or a new node (if they are not)
612
+ // they result in an entry (if keys are equal) or a new node (if they are not)
579
613
// but we want to keep it to single allocation, so we check and mark equal ones here
580
614
(dataMap and otherNode.dataMap).forEachOneBit { positionMask, _ ->
581
615
val leftKey = this .keyAtIndex(this .entryKeyIndex(positionMask))
@@ -586,7 +620,7 @@ internal class TrieNode<K, V>(
586
620
else newNodeMap = newNodeMap or positionMask
587
621
// we can use this later to skip calling equals() again
588
622
}
589
- assert (newNodeMap and newDataMap == 0 )
623
+ check (newNodeMap and newDataMap == 0 )
590
624
val mutableNode = when {
591
625
this .ownedBy == mutator.ownership && this .dataMap == newDataMap && this .nodeMap == newNodeMap -> this
592
626
else -> {
@@ -596,36 +630,7 @@ internal class TrieNode<K, V>(
596
630
}
597
631
newNodeMap.forEachOneBit { positionMask, index ->
598
632
val newNodeIndex = mutableNode.buffer.size - 1 - index
599
- mutableNode.buffer[newNodeIndex] = when {
600
- hasNodeAt(positionMask) -> {
601
- val before = nodeAtIndex(nodeIndex(positionMask))
602
- before.mutablePutAllFromOtherNodeCell(otherNode, positionMask, shift, intersectionCounter, mutator)
603
- }
604
-
605
- otherNode.hasNodeAt(positionMask) -> {
606
- val before = otherNode.nodeAtIndex(otherNode.nodeIndex(positionMask))
607
- before.mutablePutAllFromOtherNodeCell(this , positionMask, shift, intersectionCounter, mutator)
608
- }
609
-
610
- else -> { // two entries, and they are not equal by key (see ** above)
611
- val thisKeyIndex = this .entryKeyIndex(positionMask)
612
- val thisKey = this .keyAtIndex(thisKeyIndex)
613
- val thisValue = this .valueAtKeyIndex(thisKeyIndex)
614
- val otherKeyIndex = otherNode.entryKeyIndex(positionMask)
615
- val otherKey = otherNode.keyAtIndex(otherKeyIndex)
616
- val otherValue = otherNode.valueAtKeyIndex(otherKeyIndex)
617
- makeNode(
618
- thisKey.hashCode(),
619
- thisKey,
620
- thisValue,
621
- otherKey.hashCode(),
622
- otherKey,
623
- otherValue,
624
- shift + LOG_MAX_BRANCHING_FACTOR ,
625
- mutator.ownership
626
- )
627
- }
628
- }
633
+ mutableNode.buffer[newNodeIndex] = mutablePutAllFromOtherNodeCell(otherNode, positionMask, shift, intersectionCounter, mutator)
629
634
}
630
635
newDataMap.forEachOneBit { positionMask, index ->
631
636
val newKeyIndex = index * ENTRY_SIZE
0 commit comments