Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions internal/tree/red_black_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,24 @@ func (rb *RBTree[K, V]) KeyValues() ([]K, []V) {
if rb.root == nil {
return keys, values
}
rb.inOrderTraversal(func(node *rbNode[K, V]) {
rb.inOrderTraversal(func(node *rbNode[K, V]) bool {
keys = append(keys, node.key)
values = append(values, node.value)
return true
})
return keys, values
}

// Iterate 按照key的顺序遍历并执行cb,如果cb返回值为false则结束遍历,否则继续遍历
func (rb *RBTree[K, V]) Iterate(cb func(key K, value V) bool) {
rb.inOrderTraversal(func(node *rbNode[K, V]) bool {
return cb(node.key, node.value)
})
}

// inOrderTraversal 中序遍历
func (rb *RBTree[K, V]) inOrderTraversal(visit func(node *rbNode[K, V])) {
stack := make([]*rbNode[K, V], 0, rb.size)
func (rb *RBTree[K, V]) inOrderTraversal(visit func(node *rbNode[K, V]) bool) {
stack := make([]*rbNode[K, V], 0)
curr := rb.root
for curr != nil || len(stack) > 0 {
for curr != nil {
Expand All @@ -136,7 +144,9 @@ func (rb *RBTree[K, V]) inOrderTraversal(visit func(node *rbNode[K, V])) {
}
curr = stack[len(stack)-1]
stack = stack[:len(stack)-1]
visit(curr)
if !visit(curr) {
break
}
curr = curr.right
}
}
Expand Down
67 changes: 67 additions & 0 deletions internal/tree/red_black_tree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1397,6 +1397,73 @@ func TestRBTree_KeyValues(t *testing.T) {
}
}

func TestRBTree_Iterate(t *testing.T) {
for _, testCase := range []struct {
name string
expectedLen int
inputStart int
inputEnd int
// 如果为true则遍历结束
endConditionFunc func(key int) bool
}{
{
name: "treeMap为空",
expectedLen: 0,
inputStart: 1,
inputEnd: 0,
endConditionFunc: func(key int) bool {
return false
},
},
{
name: "treeMap 有10000个元素,遍历所有小于等于8000的元素",
expectedLen: 8000,
inputStart: 1,
inputEnd: 10000,
endConditionFunc: func(key int) bool {
return key > 8000
},
},
{
name: "treeMap 有10000个元素,遍历所有元素",
expectedLen: 10000,
inputStart: 1,
inputEnd: 10000,
endConditionFunc: func(key int) bool {
return false
},
},
{
name: "treeMap 有10个元素,由于第一个就不符合条件所以遍历立刻中断",
expectedLen: 0,
inputStart: 1,
inputEnd: 10,
endConditionFunc: func(key int) bool {
return key < 5
},
},
} {
t.Run(testCase.name, func(t *testing.T) {
rbTree := NewRBTree[int, int](compare())
for i := testCase.inputStart; i <= testCase.inputEnd; i++ {
assert.Nil(t, rbTree.Add(i, i))
}
arr := make([]int, 0)
rbTree.Iterate(func(key, value int) bool {
if testCase.endConditionFunc(key) {
return false
}
arr = append(arr, value)
return true
})
assert.Equal(t, testCase.expectedLen, len(arr))
for i := 0; i < testCase.expectedLen; i++ {
assert.Equal(t, testCase.inputStart+i, arr[i])
}
})
}
}

// IsRedBlackTree 检测是否满足红黑树
func IsRedBlackTree[K any, V any](root *rbNode[K, V]) bool {
// 检测节点是否黑色
Expand Down
5 changes: 5 additions & 0 deletions mapx/treemap.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,9 @@ func (treeMap *TreeMap[T, V]) Len() int64 {
return int64(treeMap.tree.Size())
}

// Iterate 按照key的顺序遍历并执行cb,如果cb返回值为false则结束遍历,否则继续遍历
func (treeMap *TreeMap[K, V]) Iterate(cb func(key K, value V) bool) {
treeMap.tree.Iterate(cb)
}

var _ mapi[any, any] = (*TreeMap[any, any])(nil)
23 changes: 23 additions & 0 deletions mapx/treemap_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,26 @@ func ExampleNewTreeMap() {
// Output:
// 11
}

func ExampleTreeMap_Iterate() {
m, _ := mapx.NewTreeMap[int, int](ekit.ComparatorRealNumber[int])
_ = m.Put(1, 11)
_ = m.Put(-1, 12)
_ = m.Put(100, 13)
_ = m.Put(-100, 14)
_ = m.Put(-101, 15)

m.Iterate(func(key, value int) bool {
if key > 1 {
return false
}
fmt.Println(key, value)
return true
})

// Output:
// -101 15
// -100 14
// -1 12
// 1 11
}
68 changes: 68 additions & 0 deletions mapx/treemap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,74 @@ func TestTreeMap_Len(t *testing.T) {
}
}

func TestRBTree_Iterate(t *testing.T) {
for _, testCase := range []struct {
name string
expectedLen int
inputStart int
inputEnd int
// 如果为true则遍历结束
endConditionFunc func(key int) bool
}{
{
name: "treeMap为空",
expectedLen: 0,
inputStart: 1,
inputEnd: 0,
endConditionFunc: func(key int) bool {
return false
},
},
{
name: "treeMap 有10000个元素,遍历所有小于等于8000的元素",
expectedLen: 8000,
inputStart: 1,
inputEnd: 10000,
endConditionFunc: func(key int) bool {
return key > 8000
},
},
{
name: "treeMap 有10000个元素,遍历所有元素",
expectedLen: 10000,
inputStart: 1,
inputEnd: 10000,
endConditionFunc: func(key int) bool {
return false
},
},
{
name: "treeMap 有10个元素,由于第一个就不符合条件所以遍历立刻中断",
expectedLen: 0,
inputStart: 1,
inputEnd: 10,
endConditionFunc: func(key int) bool {
return key < 5
},
},
} {
t.Run(testCase.name, func(t *testing.T) {
treeMap, err := NewTreeMap[int, int](compare())
assert.Nil(t, err)
for i := testCase.inputStart; i <= testCase.inputEnd; i++ {
assert.Nil(t, treeMap.Put(i, i))
}
arr := make([]int, 0)
treeMap.Iterate(func(key, value int) bool {
if testCase.endConditionFunc(key) {
return false
}
arr = append(arr, value)
return true
})
assert.Equal(t, testCase.expectedLen, len(arr))
for i := 0; i < testCase.expectedLen; i++ {
assert.Equal(t, testCase.inputStart+i, arr[i])
}
})
}
}

func compare() ekit.Comparator[int] {
return ekit.ComparatorRealNumber[int]
}
Expand Down
Loading