Skip to content

Commit 02341fe

Browse files
committed
Refactor TrieNode of PersistentHashSet
1 parent 888fbac commit 02341fe

File tree

1 file changed

+110
-92
lines changed
  • kotlinx-collections-immutable/src/main/kotlin/kotlinx/collections/immutable/implementations/immutableSet

1 file changed

+110
-92
lines changed

kotlinx-collections-immutable/src/main/kotlin/kotlinx/collections/immutable/implementations/immutableSet/TrieNode.kt

Lines changed: 110 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2016-2018 JetBrains s.r.o.
2+
* Copyright 2016-2019 JetBrains s.r.o.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -20,59 +20,81 @@ package kotlinx.collections.immutable.implementations.immutableSet
2020
internal const val MAX_BRANCHING_FACTOR = 32
2121
internal const val LOG_MAX_BRANCHING_FACTOR = 5
2222
internal const val MAX_BRANCHING_FACTOR_MINUS_ONE = MAX_BRANCHING_FACTOR - 1
23-
internal const val ENTRY_SIZE = 2
2423
internal const val MAX_SHIFT = 30
2524

26-
27-
internal class TrieNode<E>(var bitmap: Int,
28-
var buffer: Array<Any?>,
29-
var marker: Marker?) {
25+
/**
26+
* Gets the trie index segment of the specified [index] at the level specified by [shift].
27+
*
28+
* `shift` equal to zero corresponds to the root level.
29+
* For each lower level `shift` increments by [LOG_MAX_BRANCHING_FACTOR].
30+
*/
31+
internal fun indexSegment(index: Int, shift: Int): Int =
32+
(index shr shift) and MAX_BRANCHING_FACTOR_MINUS_ONE
33+
34+
35+
private fun <E> Array<Any?>.addElementAtIndex(index: Int, element: E): Array<Any?> {
36+
val newBuffer = arrayOfNulls<Any?>(this.size + 1)
37+
this.copyInto(newBuffer, endIndex = index)
38+
this.copyInto(newBuffer, index + 1, index, this.size)
39+
newBuffer[index] = element
40+
return newBuffer
41+
}
42+
43+
private fun Array<Any?>.removeCellAtIndex(cellIndex: Int): Array<Any?> {
44+
val newBuffer = arrayOfNulls<Any?>(this.size - 1)
45+
this.copyInto(newBuffer, endIndex = cellIndex)
46+
this.copyInto(newBuffer, cellIndex, cellIndex + 1, this.size)
47+
return newBuffer
48+
}
49+
50+
internal class TrieNode<E>(
51+
var bitmap: Int,
52+
var buffer: Array<Any?>,
53+
var marker: Marker?
54+
) {
3055

3156
constructor(bitmap: Int, buffer: Array<Any?>) : this(bitmap, buffer, null)
3257

33-
private fun isNullCellAt(position: Int): Boolean {
34-
return bitmap and position == 0
58+
// here and later:
59+
// positionMask — an Int of the form 2^n, i.e. with exactly one bit set; the index of that bit is the logical position of a cell in the buffer
60+
61+
private fun hasNoCellAt(positionMask: Int): Boolean {
62+
return bitmap and positionMask == 0
3563
}
3664

37-
private fun indexOfCellAt(position: Int): Int {
38-
return Integer.bitCount(bitmap and (position - 1))
65+
private fun indexOfCellAt(positionMask: Int): Int {
66+
return Integer.bitCount(bitmap and (positionMask - 1))
3967
}
4068

4169
private fun elementAtIndex(index: Int): E {
70+
@Suppress("UNCHECKED_CAST")
4271
return buffer[index] as E
4372
}
4473

4574
private fun nodeAtIndex(index: Int): TrieNode<E> {
75+
@Suppress("UNCHECKED_CAST")
4676
return buffer[index] as TrieNode<E>
4777
}
4878

49-
private fun bufferAddElementAtIndex(index: Int, element: E): Array<Any?> {
50-
val newBuffer = arrayOfNulls<Any?>(buffer.size + 1)
51-
System.arraycopy(buffer, 0, newBuffer, 0, index)
52-
System.arraycopy(buffer, index, newBuffer, index + 1, buffer.size - index)
53-
newBuffer[index] = element
54-
return newBuffer
55-
}
79+
private fun addElementAt(positionMask: Int, element: E): TrieNode<E> {
80+
// assert(hasNoCellAt(positionMask))
5681

57-
private fun addElementAt(position: Int, element: E): TrieNode<E> {
58-
// assert(isNullCellAt(position))
59-
60-
val index = indexOfCellAt(position)
61-
val newBuffer = bufferAddElementAtIndex(index, element)
62-
return TrieNode(bitmap or position, newBuffer)
82+
val index = indexOfCellAt(positionMask)
83+
val newBuffer = buffer.addElementAtIndex(index, element)
84+
return TrieNode(bitmap or positionMask, newBuffer)
6385
}
6486

65-
private fun mutableAddElementAt(position: Int, element: E, mutatorMarker: Marker): TrieNode<E> {
66-
// assert(isNullCellAt(position))
87+
private fun mutableAddElementAt(positionMask: Int, element: E, mutatorMarker: Marker): TrieNode<E> {
88+
// assert(hasNoCellAt(positionMask))
6789

68-
val index = indexOfCellAt(position)
90+
val index = indexOfCellAt(positionMask)
6991
if (marker === mutatorMarker) {
70-
buffer = bufferAddElementAtIndex(index, element)
71-
bitmap = bitmap or position
92+
buffer = buffer.addElementAtIndex(index, element)
93+
bitmap = bitmap or positionMask
7294
return this
7395
}
74-
val newBuffer = bufferAddElementAtIndex(index, element)
75-
return TrieNode(bitmap or position, newBuffer, mutatorMarker)
96+
val newBuffer = buffer.addElementAtIndex(index, element)
97+
return TrieNode(bitmap or positionMask, newBuffer, mutatorMarker)
7698
}
7799

78100
private fun updateNodeAtIndex(nodeIndex: Int, newNode: TrieNode<E>): TrieNode<E> {
@@ -104,17 +126,13 @@ internal class TrieNode<E>(var bitmap: Int,
104126

105127
private fun moveElementToNode(elementIndex: Int, newElementHash: Int, newElement: E,
106128
shift: Int): TrieNode<E> {
107-
// assert(!isNullCellAt(position))
108-
109129
val newBuffer = buffer.copyOf()
110130
newBuffer[elementIndex] = makeNodeAtIndex(elementIndex, newElementHash, newElement, shift, null)
111131
return TrieNode(bitmap, newBuffer)
112132
}
113133

114134
private fun mutableMoveElementToNode(elementIndex: Int, newElementHash: Int, newElement: E,
115135
shift: Int, mutatorMarker: Marker): TrieNode<E> {
116-
// assert(!isNullCellAt(position))
117-
118136
if (marker === mutatorMarker) {
119137
buffer[elementIndex] = makeNodeAtIndex(elementIndex, newElementHash, newElement, shift, mutatorMarker)
120138
return this
@@ -128,67 +146,63 @@ internal class TrieNode<E>(var bitmap: Int,
128146
shift: Int, mutatorMarker: Marker?): TrieNode<E> {
129147
if (shift > MAX_SHIFT) {
130148
// assert(element1 != element2)
149+
// when the two element hashes are entirely equal, the last-level subtrie node stores the elements simply as an unordered list
131150
return TrieNode<E>(0, arrayOf(element1, element2), mutatorMarker)
132151
}
133152

134-
val setBit1 = (elementHash1 shr shift) and MAX_BRANCHING_FACTOR_MINUS_ONE
135-
val setBit2 = (elementHash2 shr shift) and MAX_BRANCHING_FACTOR_MINUS_ONE
153+
val setBit1 = indexSegment(elementHash1, shift)
154+
val setBit2 = indexSegment(elementHash2, shift)
136155

137156
if (setBit1 != setBit2) {
138-
val nodeBuffer = if (setBit1 < setBit2) {
157+
val nodeBuffer = if (setBit1 < setBit2) {
139158
arrayOf<Any?>(element1, element2)
140159
} else {
141160
arrayOf<Any?>(element2, element1)
142161
}
143162
return TrieNode((1 shl setBit1) or (1 shl setBit2), nodeBuffer, mutatorMarker)
144163
}
164+
// hash segments at the given shift are equal: move these elements into the subtrie
145165
val node = makeNode(elementHash1, element1, elementHash2, element2, shift + LOG_MAX_BRANCHING_FACTOR, mutatorMarker)
146166
return TrieNode<E>(1 shl setBit1, arrayOf(node), mutatorMarker)
147167
}
148168

149-
private fun bufferRemoveCellAtIndex(cellIndex: Int): Array<Any?> {
150-
val newBuffer = arrayOfNulls<Any?>(buffer.size - 1)
151-
System.arraycopy(buffer, 0, newBuffer, 0, cellIndex)
152-
System.arraycopy(buffer, cellIndex + 1, newBuffer, cellIndex, buffer.size - cellIndex - 1)
153-
return newBuffer
154-
}
155169

156-
private fun removeCellAtIndex(cellIndex: Int, position: Int): TrieNode<E>? {
157-
// assert(!isNullCellAt(position))
158-
if (buffer.size == 1) { return null }
170+
private fun removeCellAtIndex(cellIndex: Int, positionMask: Int): TrieNode<E>? {
171+
// assert(!hasNoCellAt(positionMask))
172+
if (buffer.size == 1) return null
159173

160-
val newBuffer = bufferRemoveCellAtIndex(cellIndex)
161-
return TrieNode(bitmap xor position, newBuffer)
174+
val newBuffer = buffer.removeCellAtIndex(cellIndex)
175+
return TrieNode(bitmap xor positionMask, newBuffer)
162176
}
163177

164-
private fun mutableRemoveCellAtIndex(cellIndex: Int, position: Int, mutatorMarker: Marker): TrieNode<E>? {
165-
// assert(!isNullCellAt(position))
166-
if (buffer.size == 1) { return null }
178+
private fun mutableRemoveCellAtIndex(cellIndex: Int, positionMask: Int, mutatorMarker: Marker): TrieNode<E>? {
179+
// assert(!hasNoCellAt(positionMask))
180+
if (buffer.size == 1) return null
167181

168182
if (marker === mutatorMarker) {
169-
buffer = bufferRemoveCellAtIndex(cellIndex)
170-
bitmap = bitmap xor position
183+
buffer = buffer.removeCellAtIndex(cellIndex)
184+
bitmap = bitmap xor positionMask
171185
return this
172186
}
173-
val newBuffer = bufferRemoveCellAtIndex(cellIndex)
174-
return TrieNode(bitmap xor position, newBuffer, mutatorMarker)
187+
val newBuffer = buffer.removeCellAtIndex(cellIndex)
188+
return TrieNode(bitmap xor positionMask, newBuffer, mutatorMarker)
175189
}
176190

177191
private fun collisionRemoveElementAtIndex(i: Int): TrieNode<E>? {
178-
if (buffer.size == 1) { return null }
192+
if (buffer.size == 1) return null
179193

180-
val newBuffer = bufferRemoveCellAtIndex(i)
194+
val newBuffer = buffer.removeCellAtIndex(i)
181195
return TrieNode(0, newBuffer)
182196
}
183197

184198
private fun mutableCollisionRemoveElementAtIndex(i: Int, mutatorMarker: Marker): TrieNode<E>? {
185-
if (buffer.size == 1) { return null }
199+
if (buffer.size == 1) return null
186200

187201
if (marker === mutatorMarker) {
188-
buffer = bufferRemoveCellAtIndex(i)
202+
buffer = buffer.removeCellAtIndex(i)
189203
return this
190204
}
191-
val newBuffer = bufferRemoveCellAtIndex(i)
205+
val newBuffer = buffer.removeCellAtIndex(i)
192206
return TrieNode(0, newBuffer, mutatorMarker)
193207
}
194208

@@ -197,19 +211,19 @@ internal class TrieNode<E>(var bitmap: Int,
197211
}
198212

199213
private fun collisionAdd(element: E): TrieNode<E> {
200-
if (collisionContainsElement(element)) { return this }
201-
val newBuffer = bufferAddElementAtIndex(0, element)
214+
if (collisionContainsElement(element)) return this
215+
val newBuffer = buffer.addElementAtIndex(0, element)
202216
return TrieNode(0, newBuffer)
203217
}
204218

205219
private fun mutableCollisionAdd(element: E, mutator: PersistentHashSetBuilder<*>): TrieNode<E> {
206-
if (collisionContainsElement(element)) { return this }
220+
if (collisionContainsElement(element)) return this
207221
mutator.size++
208222
if (marker === mutator.marker) {
209-
buffer = bufferAddElementAtIndex(0, element)
223+
buffer = buffer.addElementAtIndex(0, element)
210224
return this
211225
}
212-
val newBuffer = bufferAddElementAtIndex(0, element)
226+
val newBuffer = buffer.addElementAtIndex(0, element)
213227
return TrieNode(0, newBuffer, mutator.marker)
214228
}
215229

@@ -231,13 +245,13 @@ internal class TrieNode<E>(var bitmap: Int,
231245
}
232246

233247
fun contains(elementHash: Int, element: E, shift: Int): Boolean {
234-
val cellPosition = 1 shl ((elementHash shr shift) and MAX_BRANCHING_FACTOR_MINUS_ONE)
248+
val cellPositionMask = 1 shl indexSegment(elementHash, shift)
235249

236-
if (isNullCellAt(cellPosition)) { // element is absent
250+
if (hasNoCellAt(cellPositionMask)) { // element is absent
237251
return false
238252
}
239253

240-
val cellIndex = indexOfCellAt(cellPosition)
254+
val cellIndex = indexOfCellAt(cellPositionMask)
241255
if (buffer[cellIndex] is TrieNode<*>) { // element may be in node
242256
val targetNode = nodeAtIndex(cellIndex)
243257
if (shift == MAX_SHIFT) {
@@ -250,32 +264,32 @@ internal class TrieNode<E>(var bitmap: Int,
250264
}
251265

252266
fun add(elementHash: Int, element: E, shift: Int): TrieNode<E> {
253-
val cellPosition = 1 shl ((elementHash shr shift) and MAX_BRANCHING_FACTOR_MINUS_ONE)
267+
val cellPositionMask = 1 shl indexSegment(elementHash, shift)
254268

255-
if (isNullCellAt(cellPosition)) { // element is absent
256-
return addElementAt(cellPosition, element)
269+
if (hasNoCellAt(cellPositionMask)) { // element is absent
270+
return addElementAt(cellPositionMask, element)
257271
}
258272

259-
val cellIndex = indexOfCellAt(cellPosition)
273+
val cellIndex = indexOfCellAt(cellPositionMask)
260274
if (buffer[cellIndex] is TrieNode<*>) { // element may be in node
261275
val targetNode = nodeAtIndex(cellIndex)
262276
val newNode = if (shift == MAX_SHIFT) {
263277
targetNode.collisionAdd(element)
264278
} else {
265279
targetNode.add(elementHash, element, shift + LOG_MAX_BRANCHING_FACTOR)
266280
}
267-
if (targetNode === newNode) { return this }
281+
if (targetNode === newNode) return this
268282
return updateNodeAtIndex(cellIndex, newNode)
269283
}
270284
// element is directly in buffer
271-
if (element == buffer[cellIndex]) { return this }
285+
if (element == buffer[cellIndex]) return this
272286
return moveElementToNode(cellIndex, elementHash, element, shift)
273287
}
274288

275289
fun mutableAdd(elementHash: Int, element: E, shift: Int, mutator: PersistentHashSetBuilder<*>): TrieNode<E> {
276-
val cellPosition = 1 shl ((elementHash shr shift) and MAX_BRANCHING_FACTOR_MINUS_ONE)
290+
val cellPosition = 1 shl indexSegment(elementHash, shift)
277291

278-
if (isNullCellAt(cellPosition)) { // element is absent
292+
if (hasNoCellAt(cellPosition)) { // element is absent
279293
mutator.size++
280294
return mutableAddElementAt(cellPosition, element, mutator.marker)
281295
}
@@ -288,64 +302,68 @@ internal class TrieNode<E>(var bitmap: Int,
288302
} else {
289303
targetNode.mutableAdd(elementHash, element, shift + LOG_MAX_BRANCHING_FACTOR, mutator)
290304
}
291-
if (targetNode === newNode) { return this }
305+
if (targetNode === newNode) return this
292306
return mutableUpdateNodeAtIndex(cellIndex, newNode, mutator.marker)
293307
}
294308
// element is directly in buffer
295-
if (element == buffer[cellIndex]) { return this }
309+
if (element == buffer[cellIndex]) return this
296310
mutator.size++
297311
return mutableMoveElementToNode(cellIndex, elementHash, element, shift, mutator.marker)
298312
}
299313

300314
fun remove(elementHash: Int, element: E, shift: Int): TrieNode<E>? {
301-
val cellPosition = 1 shl ((elementHash shr shift) and MAX_BRANCHING_FACTOR_MINUS_ONE)
315+
val cellPositionMask = 1 shl indexSegment(elementHash, shift)
302316

303-
if (isNullCellAt(cellPosition)) { // element is absent
317+
if (hasNoCellAt(cellPositionMask)) { // element is absent
304318
return this
305319
}
306320

307-
val cellIndex = indexOfCellAt(cellPosition)
321+
val cellIndex = indexOfCellAt(cellPositionMask)
308322
if (buffer[cellIndex] is TrieNode<*>) { // element may be in node
309323
val targetNode = nodeAtIndex(cellIndex)
310324
val newNode = if (shift == MAX_SHIFT) {
311325
targetNode.collisionRemove(element)
312326
} else {
313327
targetNode.remove(elementHash, element, shift + LOG_MAX_BRANCHING_FACTOR)
314328
}
315-
if (targetNode === newNode) { return this }
316-
if (newNode == null) { return removeCellAtIndex(cellIndex, cellPosition) }
317-
return updateNodeAtIndex(cellIndex, newNode)
329+
return when {
330+
targetNode === newNode -> this
331+
newNode == null -> removeCellAtIndex(cellIndex, cellPositionMask)
332+
else -> updateNodeAtIndex(cellIndex, newNode)
333+
}
318334
}
319335
// element is directly in buffer
320336
if (element == buffer[cellIndex]) {
321-
return removeCellAtIndex(cellIndex, cellPosition)
337+
return removeCellAtIndex(cellIndex, cellPositionMask)
322338
}
323339
return this
324340
}
325341

326342
fun mutableRemove(elementHash: Int, element: E, shift: Int, mutator: PersistentHashSetBuilder<*>): TrieNode<E>? {
327-
val cellPosition = 1 shl ((elementHash shr shift) and MAX_BRANCHING_FACTOR_MINUS_ONE)
343+
val cellPositionMask = 1 shl indexSegment(elementHash, shift)
328344

329-
if (isNullCellAt(cellPosition)) { // element is absent
345+
if (hasNoCellAt(cellPositionMask)) { // element is absent
330346
return this
331347
}
332348

333-
val cellIndex = indexOfCellAt(cellPosition)
349+
val cellIndex = indexOfCellAt(cellPositionMask)
334350
if (buffer[cellIndex] is TrieNode<*>) { // element may be in node
335351
val targetNode = nodeAtIndex(cellIndex)
336352
val newNode = if (shift == MAX_SHIFT) {
337353
targetNode.mutableCollisionRemove(element, mutator)
338354
} else {
339355
targetNode.mutableRemove(elementHash, element, shift + LOG_MAX_BRANCHING_FACTOR, mutator)
340356
}
341-
if (targetNode === newNode) { return this }
342-
if (newNode == null) { return mutableRemoveCellAtIndex(cellIndex, cellPosition, mutator.marker) }
343-
return mutableUpdateNodeAtIndex(cellIndex, newNode, mutator.marker)
357+
return when {
358+
targetNode === newNode -> this
359+
newNode == null -> mutableRemoveCellAtIndex(cellIndex, cellPositionMask, mutator.marker)
360+
else -> mutableUpdateNodeAtIndex(cellIndex, newNode, mutator.marker)
361+
}
344362
}
345363
// element is directly in buffer
346364
if (element == buffer[cellIndex]) {
347365
mutator.size--
348-
return mutableRemoveCellAtIndex(cellIndex, cellPosition, mutator.marker) // check is empty
366+
return mutableRemoveCellAtIndex(cellIndex, cellPositionMask, mutator.marker) // check is empty
349367
}
350368
return this
351369
}

0 commit comments

Comments
 (0)