|
1 | 1 | package groups |
2 | 2 |
|
3 | 3 | import ( |
4 | | - "github.com/pkg/errors" |
5 | | - "github.com/stretchr/testify/assert" |
6 | 4 | "sync" |
7 | 5 | "sync/atomic" |
8 | 6 | "testing" |
9 | 7 | "time" |
| 8 | + |
| 9 | + "github.com/pkg/errors" |
| 10 | + "github.com/stretchr/testify/assert" |
10 | 11 | ) |
11 | 12 |
|
12 | 13 | func TestNewTaskGroup(t *testing.T) { |
@@ -103,4 +104,136 @@ func TestNewTaskGroup(t *testing.T) { |
103 | 104 | as.Error(err) |
104 | 105 | as.ElementsMatch(arr, []int{1, 3}) |
105 | 106 | }) |
| 107 | + |
| 108 | + t.Run("push after start", func(t *testing.T) { |
| 109 | + // 测试在Start之后Push任务的行为 |
| 110 | + ctl := New[int](WithConcurrency(2)) |
| 111 | + ctl.Push(1, 2) |
| 112 | + var processed = make([]int, 0) |
| 113 | + var mu = sync.Mutex{} |
| 114 | + ctl.OnMessage = func(args int) error { |
| 115 | + mu.Lock() |
| 116 | + processed = append(processed, args) |
| 117 | + mu.Unlock() |
| 118 | + time.Sleep(10 * time.Millisecond) |
| 119 | + return nil |
| 120 | + } |
| 121 | + |
| 122 | + go func() { |
| 123 | + time.Sleep(5 * time.Millisecond) |
| 124 | + ctl.Push(3, 4) // 在Start之后Push |
| 125 | + }() |
| 126 | + |
| 127 | + err := ctl.Start() |
| 128 | + as.NoError(err) |
| 129 | + // 验证所有任务都被处理 |
| 130 | + as.GreaterOrEqual(len(processed), 2) |
| 131 | + }) |
| 132 | + |
| 133 | + t.Run("multiple errors", func(t *testing.T) { |
| 134 | + // 测试多个任务都出错的情况 |
| 135 | + ctl := New[int](WithConcurrency(2)) |
| 136 | + ctl.Push(1, 2, 3) |
| 137 | + ctl.OnMessage = func(args int) error { |
| 138 | + return errors.Errorf("error %d", args) |
| 139 | + } |
| 140 | + err := ctl.Start() |
| 141 | + as.Error(err) |
| 142 | + // 验证错误被正确收集 |
| 143 | + as.Contains(err.Error(), "error") |
| 144 | + }) |
| 145 | + |
| 146 | + t.Run("concurrent push", func(t *testing.T) { |
| 147 | + // 测试并发Push |
| 148 | + ctl := New[int](WithConcurrency(4)) |
| 149 | + var wg sync.WaitGroup |
| 150 | + for i := 0; i < 10; i++ { |
| 151 | + wg.Add(1) |
| 152 | + go func(n int) { |
| 153 | + defer wg.Done() |
| 154 | + ctl.Push(n) |
| 155 | + }(i) |
| 156 | + } |
| 157 | + wg.Wait() |
| 158 | + |
| 159 | + var processed = int64(0) |
| 160 | + ctl.OnMessage = func(args int) error { |
| 161 | + atomic.AddInt64(&processed, 1) |
| 162 | + return nil |
| 163 | + } |
| 164 | + err := ctl.Start() |
| 165 | + as.NoError(err) |
| 166 | + as.Equal(int64(10), processed) |
| 167 | + }) |
| 168 | + |
| 169 | + t.Run("zero concurrency", func(t *testing.T) { |
| 170 | + // 测试并发度为0的情况(应该使用默认值) |
| 171 | + ctl := New[int](WithConcurrency(0)) |
| 172 | + ctl.Push(1, 2, 3) |
| 173 | + var processed = int64(0) |
| 174 | + ctl.OnMessage = func(args int) error { |
| 175 | + atomic.AddInt64(&processed, 1) |
| 176 | + return nil |
| 177 | + } |
| 178 | + err := ctl.Start() |
| 179 | + as.NoError(err) |
| 180 | + as.Equal(int64(3), processed) |
| 181 | + }) |
| 182 | + |
| 183 | + t.Run("cancel before start", func(t *testing.T) { |
| 184 | + // 测试在Start之前Cancel |
| 185 | + ctl := New[int]() |
| 186 | + ctl.Push(1, 2, 3) |
| 187 | + ctl.Cancel() |
| 188 | + var processed = int64(0) |
| 189 | + ctl.OnMessage = func(args int) error { |
| 190 | + atomic.AddInt64(&processed, 1) |
| 191 | + return nil |
| 192 | + } |
| 193 | + _ = ctl.Start() |
| 194 | + // Cancel后任务应该不会执行 |
| 195 | + as.Equal(int64(0), processed) |
| 196 | + // 由于任务被取消,可能会超时或返回错误 |
| 197 | + }) |
| 198 | + |
| 199 | + t.Run("update deadlock check", func(t *testing.T) { |
| 200 | + // 测试Update方法不会导致死锁 |
| 201 | + ctl := New[int](WithConcurrency(2)) |
| 202 | + ctl.Push(1, 2, 3, 4, 5) |
| 203 | + var processed = int64(0) |
| 204 | + ctl.OnMessage = func(args int) error { |
| 205 | + ctl.Update(func() { |
| 206 | + processed++ |
| 207 | + }) |
| 208 | + return nil |
| 209 | + } |
| 210 | + err := ctl.Start() |
| 211 | + as.NoError(err) |
| 212 | + as.Equal(int64(5), processed) |
| 213 | + }) |
| 214 | + |
| 215 | + t.Run("task count accuracy", func(t *testing.T) { |
| 216 | + // 测试任务计数准确性 |
| 217 | + ctl := New[int](WithConcurrency(2)) |
| 218 | + ctl.Push(1, 2, 3, 4, 5) |
| 219 | + var processed = int64(0) |
| 220 | + ctl.OnMessage = func(args int) error { |
| 221 | + atomic.AddInt64(&processed, 1) |
| 222 | + return nil |
| 223 | + } |
| 224 | + err := ctl.Start() |
| 225 | + as.NoError(err) |
| 226 | + as.Equal(int64(5), processed) |
| 227 | + }) |
| 228 | + |
| 229 | + t.Run("len accuracy", func(t *testing.T) { |
| 230 | + // 测试Len方法的准确性 |
| 231 | + ctl := New[int]() |
| 232 | + as.Equal(0, ctl.Len()) |
| 233 | + ctl.Push(1, 2, 3) |
| 234 | + as.Equal(3, ctl.Len()) |
| 235 | + ctl.Start() |
| 236 | + // Start后队列应该为空 |
| 237 | + as.Equal(0, ctl.Len()) |
| 238 | + }) |
106 | 239 | } |
0 commit comments