@@ -27,11 +27,12 @@ import java.util.concurrent.ConcurrentHashMap
27
27
import java.util.concurrent.TimeUnit
28
28
import java.util.concurrent.locks.LockSupport
29
29
30
+ private const val WAIT_LOST_THREADS = 10_000L // 10s
30
31
private val ignoreLostThreads = mutableSetOf<String >()
31
32
32
33
fun ignoreLostThreads (vararg s : String ) { ignoreLostThreads + = s }
33
34
34
- fun threadNames (): Set <String > {
35
+ fun currentThreads (): Set <Thread > {
35
36
var estimate = 0
36
37
while (true ) {
37
38
estimate = estimate.coerceAtLeast(Thread .activeCount() + 1 )
@@ -41,33 +42,37 @@ fun threadNames(): Set<String> {
41
42
estimate = n + 1
42
43
continue // retry with a better size estimate
43
44
}
44
- val names = hashSetOf<String >()
45
+ val threads = hashSetOf<Thread >()
45
46
for (i in 0 until n)
46
- names .add(sanitizeThreadName( arrayOfThreads[i]!! .name) )
47
- return names
47
+ threads .add(arrayOfThreads[i]!! )
48
+ return threads
48
49
}
49
50
}
50
51
51
- // remove coroutine names from thread in case we have lost threads with coroutines running in them
52
- private fun sanitizeThreadName (name : String ): String {
53
- val i = name.indexOf(" @" )
54
- return if (i < 0 ) name else name.substring(0 , i)
55
- }
56
-
57
- fun checkTestThreads (threadNamesBefore : Set <String >) {
52
+ fun checkTestThreads (threadsBefore : Set <Thread >) {
58
53
// give threads some time to shutdown
59
- val waitTill = System .currentTimeMillis() + 1000L
60
- var diff: List <String >
54
+ val waitTill = System .currentTimeMillis() + WAIT_LOST_THREADS
55
+ var diff: List <Thread >
61
56
do {
62
- val threadNamesAfter = threadNames ()
63
- diff = (threadNamesAfter - threadNamesBefore ).filter { name ->
64
- ignoreLostThreads.none { prefix -> name.startsWith(prefix) }
57
+ val threadsAfter = currentThreads ()
58
+ diff = (threadsAfter - threadsBefore ).filter { thread ->
59
+ ignoreLostThreads.none { prefix -> thread. name.startsWith(prefix) }
65
60
}
66
61
if (diff.isEmpty()) break
67
62
} while (System .currentTimeMillis() <= waitTill)
68
63
ignoreLostThreads.clear()
69
- diff.forEach { println (" Lost thread '$it '" ) }
70
- check(diff.isEmpty()) { " Lost ${diff.size} threads" }
64
+ if (diff.isEmpty()) return
65
+ val message = " Lost threads ${diff.map { it.name }} "
66
+ println (" !!! $message " )
67
+ println (" === Dumping lost thread stack traces" )
68
+ diff.forEach { thread ->
69
+ println (" Thread \" ${thread.name} \" ${thread.state} " )
70
+ val trace = thread.stackTrace
71
+ for (t in trace) println (" \t at ${t.className} .${t.methodName} (${t.fileName} :${t.lineNumber} )" )
72
+ println ()
73
+ }
74
+ println (" ===" )
75
+ error(message)
71
76
}
72
77
73
78
fun trackTask (block : Runnable ) = timeSource.trackTask(block)
@@ -96,7 +101,7 @@ fun test(name: String, block: () -> Unit): List<String> = outputException(name)
96
101
resetCoroutineId()
97
102
// shutdown execution with old time source (in case it was working)
98
103
DefaultExecutor .shutdown(SHUTDOWN_TIMEOUT )
99
- val threadNamesBefore = threadNames ()
104
+ val threadsBefore = currentThreads ()
100
105
val testTimeSource = TestTimeSource (oldOut)
101
106
timeSource = testTimeSource
102
107
DefaultExecutor .ensureStarted() // should start with new time source
@@ -121,7 +126,7 @@ fun test(name: String, block: () -> Unit): List<String> = outputException(name)
121
126
oldOut.println (" --- done" )
122
127
System .setOut(oldOut)
123
128
System .setErr(oldErr)
124
- checkTestThreads(threadNamesBefore )
129
+ checkTestThreads(threadsBefore )
125
130
}
126
131
return ByteArrayInputStream (bytes).bufferedReader().readLines()
127
132
}
0 commit comments