|
| 1 | +/*************************************************************** |
| 2 | + * |
| 3 | + * Copyright (C) 2026, Pelican Project, Morgridge Institute for Research |
| 4 | + * |
| 5 | + * Licensed under the Apache License, Version 2.0 (the "License"); you |
| 6 | + * may not use this file except in compliance with the License. You may |
| 7 | + * obtain a copy of the License at |
| 8 | + * |
| 9 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | + * |
| 11 | + * Unless required by applicable law or agreed to in writing, software |
| 12 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | + * See the License for the specific language governing permissions and |
| 15 | + * limitations under the License. |
| 16 | + * |
| 17 | + ***************************************************************/ |
| 18 | + |
| 19 | +package htb |
| 20 | + |
| 21 | +import ( |
| 22 | + "context" |
| 23 | + "sync" |
| 24 | + "testing" |
| 25 | + "time" |
| 26 | + |
| 27 | + "github.com/stretchr/testify/require" |
| 28 | +) |
| 29 | + |
| 30 | +// TestHighConcurrencyNoDeadlock tests that high concurrency doesn't cause deadlock |
| 31 | +// This test specifically addresses the issue where C16 was deadlocking |
| 32 | +func TestHighConcurrencyNoDeadlock(t *testing.T) { |
| 33 | + h := New(1000*1000*1000, 1000*1000*1000) // 1 second capacity (large enough for test) |
| 34 | + |
| 35 | + numWorkers := 20 |
| 36 | + opsPerWorker := 50 |
| 37 | + |
| 38 | + done := make(chan bool, numWorkers) |
| 39 | + start := make(chan struct{}) |
| 40 | + |
| 41 | + // Launch workers |
| 42 | + for i := 0; i < numWorkers; i++ { |
| 43 | + go func(workerID int) { |
| 44 | + userID := "user" + string(rune('A'+workerID)) |
| 45 | + <-start // Wait for signal to start |
| 46 | + |
| 47 | + for j := 0; j < opsPerWorker; j++ { |
| 48 | + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) |
| 49 | + |
| 50 | + // Request small amount of tokens (1ms each) |
| 51 | + tokens, err := h.Wait(ctx, userID, 1*1000*1000) // 1ms |
| 52 | + if err != nil { |
| 53 | + t.Errorf("Worker %d op %d: %v", workerID, j, err) |
| 54 | + cancel() |
| 55 | + done <- false |
| 56 | + return |
| 57 | + } |
| 58 | + |
| 59 | + // Simulate very brief work (in-memory operation) |
| 60 | + // Don't actually sleep to keep test fast |
| 61 | + |
| 62 | + // Return tokens immediately |
| 63 | + tokens.Use(500 * 1000) // Use half (0.5ms) |
| 64 | + h.Return(tokens) // Return 0.5ms |
| 65 | + |
| 66 | + cancel() |
| 67 | + } |
| 68 | + |
| 69 | + done <- true |
| 70 | + }(i) |
| 71 | + } |
| 72 | + |
| 73 | + // Start all workers simultaneously |
| 74 | + close(start) |
| 75 | + |
| 76 | + // Wait for all workers with timeout |
| 77 | + timeout := time.After(30 * time.Second) |
| 78 | + for i := 0; i < numWorkers; i++ { |
| 79 | + select { |
| 80 | + case success := <-done: |
| 81 | + if !success { |
| 82 | + t.Fatal("Worker failed") |
| 83 | + } |
| 84 | + case <-timeout: |
| 85 | + t.Fatal("Deadlock detected: test timed out waiting for workers") |
| 86 | + } |
| 87 | + } |
| 88 | +} |
| 89 | + |
| 90 | +// TestWaiterWakeupOnReturn verifies that returning tokens wakes up waiters |
| 91 | +func TestWaiterWakeupOnReturn(t *testing.T) { |
| 92 | + h := New(10*1000*1000, 10*1000*1000) // 10ms capacity |
| 93 | + |
| 94 | + ctx := context.Background() |
| 95 | + |
| 96 | + // Take all tokens |
| 97 | + tokens1, err := h.Wait(ctx, "user1", 10*1000*1000) |
| 98 | + require.NoError(t, err) |
| 99 | + |
| 100 | + // Try to take more - should block |
| 101 | + waitDone := make(chan bool) |
| 102 | + go func() { |
| 103 | + ctx2, cancel := context.WithTimeout(context.Background(), 2*time.Second) |
| 104 | + defer cancel() |
| 105 | + tokens2, err := h.Wait(ctx2, "user2", 5*1000*1000) // 5ms |
| 106 | + if err != nil { |
| 107 | + t.Error("Wait failed:", err) |
| 108 | + waitDone <- false |
| 109 | + return |
| 110 | + } |
| 111 | + h.Return(tokens2) |
| 112 | + waitDone <- true |
| 113 | + }() |
| 114 | + |
| 115 | + // Wait a bit to ensure goroutine is blocked |
| 116 | + time.Sleep(100 * time.Millisecond) |
| 117 | + |
| 118 | + // Return tokens - should wake up the waiter |
| 119 | + tokens1.Use(5 * 1000 * 1000) // Use 5ms |
| 120 | + h.Return(tokens1) // Return 5ms |
| 121 | + |
| 122 | + // Waiter should now proceed |
| 123 | + select { |
| 124 | + case success := <-waitDone: |
| 125 | + if !success { |
| 126 | + t.Fatal("Waiter failed") |
| 127 | + } |
| 128 | + case <-time.After(1 * time.Second): |
| 129 | + t.Fatal("Waiter was not woken up by Return()") |
| 130 | + } |
| 131 | +} |
| 132 | + |
| 133 | +// TestMultipleWaitersWakeup tests that multiple waiters can be woken up |
| 134 | +func TestMultipleWaitersWakeup(t *testing.T) { |
| 135 | + h := New(30*1000*1000, 30*1000*1000) // 30ms capacity |
| 136 | + |
| 137 | + ctx := context.Background() |
| 138 | + |
| 139 | + // Take most tokens |
| 140 | + tokens1, err := h.Wait(ctx, "user1", 25*1000*1000) // 25ms |
| 141 | + require.NoError(t, err) |
| 142 | + |
| 143 | + // Launch multiple waiters |
| 144 | + numWaiters := 5 |
| 145 | + waiters := make([]chan bool, numWaiters) |
| 146 | + for i := 0; i < numWaiters; i++ { |
| 147 | + waiters[i] = make(chan bool, 1) |
| 148 | + go func(id int, done chan bool) { |
| 149 | + ctx2, cancel := context.WithTimeout(context.Background(), 3*time.Second) |
| 150 | + defer cancel() |
| 151 | + |
| 152 | + userID := "waiter" + string(rune('A'+id)) |
| 153 | + tokens, err := h.Wait(ctx2, userID, 5*1000*1000) // 5ms each |
| 154 | + if err != nil { |
| 155 | + t.Errorf("Waiter %d failed: %v", id, err) |
| 156 | + done <- false |
| 157 | + return |
| 158 | + } |
| 159 | + h.Return(tokens) |
| 160 | + done <- true |
| 161 | + }(i, waiters[i]) |
| 162 | + } |
| 163 | + |
| 164 | + // Wait for all waiters to be queued |
| 165 | + time.Sleep(100 * time.Millisecond) |
| 166 | + |
| 167 | + // Return tokens gradually |
| 168 | + tokens1.Use(10 * 1000 * 1000) // Use 10ms |
| 169 | + h.Return(tokens1) // Return 15ms |
| 170 | + |
| 171 | + // All waiters should eventually proceed |
| 172 | + timeout := time.After(2 * time.Second) |
| 173 | + for i := 0; i < numWaiters; i++ { |
| 174 | + select { |
| 175 | + case success := <-waiters[i]: |
| 176 | + if !success { |
| 177 | + t.Fatalf("Waiter %d failed", i) |
| 178 | + } |
| 179 | + case <-timeout: |
| 180 | + t.Fatalf("Waiter %d timed out - not all waiters were woken up", i) |
| 181 | + } |
| 182 | + } |
| 183 | +} |
| 184 | + |
| 185 | +// TestReturnToParentWakesAllChildren tests that returning to parent processes all children |
| 186 | +func TestReturnToParentWakesAllChildren(t *testing.T) { |
| 187 | + h := New(20*1000*1000, 20*1000*1000) // 20ms capacity |
| 188 | + |
| 189 | + ctx := context.Background() |
| 190 | + |
| 191 | + // Create multiple users and exhaust their tokens |
| 192 | + tokens1, _ := h.Wait(ctx, "user1", 7*1000*1000) |
| 193 | + tokens2, _ := h.Wait(ctx, "user2", 7*1000*1000) |
| 194 | + tokens3, _ := h.Wait(ctx, "user3", 6*1000*1000) |
| 195 | + |
| 196 | + // Now all children should have minimal tokens |
| 197 | + // Launch waiters for different users |
| 198 | + var wg sync.WaitGroup |
| 199 | + for i := 0; i < 3; i++ { |
| 200 | + wg.Add(1) |
| 201 | + go func(id int) { |
| 202 | + defer wg.Done() |
| 203 | + userID := "user" + string(rune('1'+id)) |
| 204 | + ctx2, cancel := context.WithTimeout(context.Background(), 2*time.Second) |
| 205 | + defer cancel() |
| 206 | + |
| 207 | + tokens, err := h.Wait(ctx2, userID, 5*1000*1000) // 5ms |
| 208 | + if err != nil { |
| 209 | + t.Errorf("User %d wait failed: %v", id+1, err) |
| 210 | + return |
| 211 | + } |
| 212 | + h.Return(tokens) |
| 213 | + }(i) |
| 214 | + } |
| 215 | + |
| 216 | + // Wait for waiters to queue |
| 217 | + time.Sleep(100 * time.Millisecond) |
| 218 | + |
| 219 | + // Return tokens from user1 - with overflow, should go to parent |
| 220 | + // and wake up waiters for all users |
| 221 | + tokens1.Use(2 * 1000 * 1000) |
| 222 | + h.Return(tokens1) // Returns 5ms, should overflow to parent |
| 223 | + |
| 224 | + tokens2.Use(2 * 1000 * 1000) |
| 225 | + h.Return(tokens2) |
| 226 | + |
| 227 | + tokens3.Use(2 * 1000 * 1000) |
| 228 | + h.Return(tokens3) |
| 229 | + |
| 230 | + // All waiters should complete |
| 231 | + done := make(chan struct{}) |
| 232 | + go func() { |
| 233 | + wg.Wait() |
| 234 | + close(done) |
| 235 | + }() |
| 236 | + |
| 237 | + select { |
| 238 | + case <-done: |
| 239 | + // Success |
| 240 | + case <-time.After(2 * time.Second): |
| 241 | + t.Fatal("Not all waiters were woken up when tokens returned to parent") |
| 242 | + } |
| 243 | +} |
0 commit comments