Skip to content

Commit 06a2423

Browse files
Abduqodiri Qurbonzodaqurbonzoda
authored andcommitted
PeristentHashMapBuilder.putAll() with another persistent hash map can produce incorrect results #114
When putting all entries of a cell (1) into another cell (2), if (1) is an entry and (2) is a node, for optimization reasons the entry is put into the node. This leads to saving the old value of the entry if the node already contains the key.
1 parent f5f852b commit 06a2423

File tree

2 files changed

+76
-61
lines changed

2 files changed

+76
-61
lines changed

core/commonMain/src/implementations/immutableMap/TrieNode.kt

Lines changed: 66 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -460,38 +460,72 @@ internal class TrieNode<K, V>(
460460
}
461461
}
462462

463-
private fun mutablePutAllFromOtherNodeCell(other: TrieNode<K, V>,
464-
positionMask: Int,
465-
shift: Int,
466-
intersectionCounter: DeltaCounter,
467-
mutator: PersistentHashMapBuilder<K, V>): TrieNode<K, V> {
468-
return when {
469-
other.hasNodeAt(positionMask) -> {
470-
mutablePutAll(
471-
other.nodeAtIndex(other.nodeIndex(positionMask)),
472-
shift + LOG_MAX_BRANCHING_FACTOR,
473-
intersectionCounter,
474-
mutator
475-
)
463+
/**
464+
* Updates the cell of this node at [positionMask] with entries from the cell of [otherNode] at [positionMask].
465+
*/
466+
private fun mutablePutAllFromOtherNodeCell(
467+
otherNode: TrieNode<K, V>,
468+
positionMask: Int,
469+
shift: Int,
470+
intersectionCounter: DeltaCounter,
471+
mutator: PersistentHashMapBuilder<K, V>
472+
): TrieNode<K, V> = when {
473+
this.hasNodeAt(positionMask) -> {
474+
val targetNode = this.nodeAtIndex(nodeIndex(positionMask))
475+
when {
476+
otherNode.hasNodeAt(positionMask) -> {
477+
val otherTargetNode = otherNode.nodeAtIndex(otherNode.nodeIndex(positionMask))
478+
targetNode.mutablePutAll(otherTargetNode, shift + LOG_MAX_BRANCHING_FACTOR, intersectionCounter, mutator)
479+
}
480+
otherNode.hasEntryAt(positionMask) -> {
481+
val keyIndex = otherNode.entryKeyIndex(positionMask)
482+
val key = otherNode.keyAtIndex(keyIndex)
483+
val value = otherNode.valueAtKeyIndex(keyIndex)
484+
val oldSize = mutator.size
485+
targetNode.mutablePut(key.hashCode(), key, value, shift + LOG_MAX_BRANCHING_FACTOR, mutator).also {
486+
if (mutator.size == oldSize) intersectionCounter.count++
487+
}
488+
}
489+
else -> targetNode
476490
}
477-
other.hasEntryAt(positionMask) -> {
478-
val keyIndex = other.entryKeyIndex(positionMask)
479-
val key = other.keyAtIndex(keyIndex)
480-
val value = other.valueAtKeyIndex(keyIndex)
481-
val oldSize = mutator.size
482-
val newNode = mutablePut(
483-
key.hashCode(),
484-
key,
485-
value,
486-
shift + LOG_MAX_BRANCHING_FACTOR,
487-
mutator
488-
)
489-
if (mutator.size == oldSize) {
490-
intersectionCounter.count++
491+
}
492+
493+
otherNode.hasNodeAt(positionMask) -> {
494+
val otherTargetNode = otherNode.nodeAtIndex(otherNode.nodeIndex(positionMask))
495+
when {
496+
this.hasEntryAt(positionMask) -> {
497+
// if otherTargetNode already has a value associated with the key, do not put this entry
498+
val keyIndex = this.entryKeyIndex(positionMask)
499+
val key = this.keyAtIndex(keyIndex)
500+
if (otherTargetNode.containsKey(key.hashCode(), key, shift + LOG_MAX_BRANCHING_FACTOR)) {
501+
intersectionCounter.count++
502+
otherTargetNode
503+
} else {
504+
val value = this.valueAtKeyIndex(keyIndex)
505+
otherTargetNode.mutablePut(key.hashCode(), key, value, shift + LOG_MAX_BRANCHING_FACTOR, mutator)
506+
}
491507
}
492-
newNode
508+
else -> otherTargetNode
493509
}
494-
else -> this
510+
}
511+
512+
else -> { // two entries, and they are not equal by key. See (**) in mutablePutAll
513+
val thisKeyIndex = this.entryKeyIndex(positionMask)
514+
val thisKey = this.keyAtIndex(thisKeyIndex)
515+
val thisValue = this.valueAtKeyIndex(thisKeyIndex)
516+
val otherKeyIndex = otherNode.entryKeyIndex(positionMask)
517+
val otherKey = otherNode.keyAtIndex(otherKeyIndex)
518+
val otherValue = otherNode.valueAtKeyIndex(otherKeyIndex)
519+
makeNode(
520+
thisKey.hashCode(),
521+
thisKey,
522+
thisValue,
523+
otherKey.hashCode(),
524+
otherKey,
525+
otherValue,
526+
shift + LOG_MAX_BRANCHING_FACTOR,
527+
mutator.ownership
528+
)
495529
}
496530
}
497531

@@ -575,7 +609,7 @@ internal class TrieNode<K, V>(
575609
// but not in the new data nodes
576610
var newDataMap = dataMap xor otherNode.dataMap and newNodeMap.inv()
577611
// (**) now, this is tricky: we have a number of entry-entry pairs and we don't know yet whether
578-
// they result in an entry (if they are equal) or a new node (if they are not)
612+
// they result in an entry (if keys are equal) or a new node (if they are not)
579613
// but we want to keep it to single allocation, so we check and mark equal ones here
580614
(dataMap and otherNode.dataMap).forEachOneBit { positionMask, _ ->
581615
val leftKey = this.keyAtIndex(this.entryKeyIndex(positionMask))
@@ -586,7 +620,7 @@ internal class TrieNode<K, V>(
586620
else newNodeMap = newNodeMap or positionMask
587621
// we can use this later to skip calling equals() again
588622
}
589-
assert(newNodeMap and newDataMap == 0)
623+
check(newNodeMap and newDataMap == 0)
590624
val mutableNode = when {
591625
this.ownedBy == mutator.ownership && this.dataMap == newDataMap && this.nodeMap == newNodeMap -> this
592626
else -> {
@@ -596,36 +630,7 @@ internal class TrieNode<K, V>(
596630
}
597631
newNodeMap.forEachOneBit { positionMask, index ->
598632
val newNodeIndex = mutableNode.buffer.size - 1 - index
599-
mutableNode.buffer[newNodeIndex] = when {
600-
hasNodeAt(positionMask) -> {
601-
val before = nodeAtIndex(nodeIndex(positionMask))
602-
before.mutablePutAllFromOtherNodeCell(otherNode, positionMask, shift, intersectionCounter, mutator)
603-
}
604-
605-
otherNode.hasNodeAt(positionMask) -> {
606-
val before = otherNode.nodeAtIndex(otherNode.nodeIndex(positionMask))
607-
before.mutablePutAllFromOtherNodeCell(this, positionMask, shift, intersectionCounter, mutator)
608-
}
609-
610-
else -> { // two entries, and they are not equal by key (see ** above)
611-
val thisKeyIndex = this.entryKeyIndex(positionMask)
612-
val thisKey = this.keyAtIndex(thisKeyIndex)
613-
val thisValue = this.valueAtKeyIndex(thisKeyIndex)
614-
val otherKeyIndex = otherNode.entryKeyIndex(positionMask)
615-
val otherKey = otherNode.keyAtIndex(otherKeyIndex)
616-
val otherValue = otherNode.valueAtKeyIndex(otherKeyIndex)
617-
makeNode(
618-
thisKey.hashCode(),
619-
thisKey,
620-
thisValue,
621-
otherKey.hashCode(),
622-
otherKey,
623-
otherValue,
624-
shift + LOG_MAX_BRANCHING_FACTOR,
625-
mutator.ownership
626-
)
627-
}
628-
}
633+
mutableNode.buffer[newNodeIndex] = mutablePutAllFromOtherNodeCell(otherNode, positionMask, shift, intersectionCounter, mutator)
629634
}
630635
newDataMap.forEachOneBit { positionMask, index ->
631636
val newKeyIndex = index * ENTRY_SIZE

core/commonTest/src/contract/map/ImmutableMapTest.kt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,16 @@ class ImmutableHashMapTest : ImmutableMapTest() {
8787

8888
assertTrue(map1.containsKey(32))
8989
}
90+
91+
@Test
92+
fun regressionGithubIssue114() {
93+
// https://github.com/Kotlin/kotlinx.collections.immutable/issues/114
94+
val p = persistentMapOf(99 to 1)
95+
val e = Array(101) { it }.map { it to it }
96+
val c = persistentMapOf(*e.toTypedArray())
97+
val n = p.builder().apply { putAll(c) }.build()
98+
assertEquals(99, n[99])
99+
}
90100
}
91101

92102
class ImmutableOrderedMapTest : ImmutableMapTest() {

0 commit comments

Comments
 (0)