
Commit 98b04d4

MVP of new coroutine scheduler

1 parent fc87803 · commit 98b04d4

File tree: 11 files changed, +720 −0 lines


benchmarks/src/jmh/kotlin/benchmarks/ParametrizedDispatcherBase.kt

Lines changed: 5 additions & 0 deletions
@@ -1,8 +1,10 @@
 package benchmarks
 
+import benchmarks.actors.CORES_COUNT
 import kotlinx.coroutines.experimental.CommonPool
 import kotlinx.coroutines.experimental.ThreadPoolDispatcher
 import kotlinx.coroutines.experimental.newFixedThreadPoolContext
+import kotlinx.coroutines.experimental.scheduling.ExperimentalCoroutineDispatcher
 import org.openjdk.jmh.annotations.Param
 import org.openjdk.jmh.annotations.Setup
 import org.openjdk.jmh.annotations.TearDown

@@ -23,6 +25,9 @@ abstract class ParametrizedDispatcherBase {
     open fun setup() {
         benchmarkContext = when {
             dispatcher == "fjp" -> CommonPool
+            dispatcher == "experimental" -> {
+                ExperimentalCoroutineDispatcher(CORES_COUNT)
+            }
             dispatcher.startsWith("ftp") -> {
                 newFixedThreadPoolContext(dispatcher.substring(4).toInt(), dispatcher).also { closeable = it }
             }
Lines changed: 107 additions & 0 deletions (new file)

package kotlinx.coroutines.experimental.scheduling

import kotlinx.coroutines.experimental.Runnable
import java.io.Closeable
import java.util.*
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.Executor
import java.util.concurrent.locks.LockSupport

/**
 * TODO design rationale
 */
class CoroutineScheduler(private val corePoolSize: Int) : Executor, Closeable {

    private val workers: Array<PoolWorker>
    private val globalWorkQueue: Queue<Task> = ConcurrentLinkedQueue<Task>()
    @Volatile
    private var isClosed = false

    init {
        require(corePoolSize >= 1, { "Expected positive core pool size, but was $corePoolSize" })
        workers = Array(corePoolSize, { PoolWorker(it) })
        workers.forEach { it.start() }
    }

    override fun execute(command: Runnable) = dispatch(command)

    override fun close() {
        isClosed = true
    }

    fun dispatch(command: Runnable, intensive: Boolean = false) {
        val task = TimedTask(System.nanoTime(), command)
        if (!submitToLocalQueue(task, intensive)) {
            globalWorkQueue.add(task)
        }
    }

    private fun submitToLocalQueue(task: Task, intensive: Boolean): Boolean {
        val worker = Thread.currentThread() as? PoolWorker ?: return false
        if (intensive && worker.localQueue.bufferSize > FORKED_TASK_OFFLOAD_THRESHOLD) return false
        worker.localQueue.offer(task, globalWorkQueue)
        return true
    }

    private inner class PoolWorker(index: Int) : Thread("CoroutinesScheduler-worker-$index") {
        init {
            isDaemon = true
        }

        val localQueue: WorkQueue = WorkQueue()

        @Volatile
        var yields = 0

        override fun run() {
            while (!isClosed) {
                try {
                    val job = findTask()
                    if (job == null) {
                        awaitWork()
                    } else {
                        yields = 0
                        job.task.run()
                    }
                } catch (e: Throwable) {
                    println(e) // TODO handler
                }
            }
        }

        private fun awaitWork() {
            // Temporary solution
            if (++yields > 100000) {
                LockSupport.parkNanos(WORK_STEALING_TIME_RESOLUTION / 2)
            }
        }

        private fun findTask(): Task? {
            // TODO explain, probabilistic check with park counter?
            var task: Task? = globalWorkQueue.poll()
            if (task != null) return task

            task = localQueue.poll()
            if (task != null) return task

            return trySteal()
        }

        private fun trySteal(): Task? {
            if (corePoolSize == 1) {
                return null
            }

            while (true) {
                val worker = workers[RANDOM_PROVIDER().nextInt(workers.size)]
                if (worker !== this) {
                    worker.localQueue.offloadWork(true) {
                        localQueue.offer(it, globalWorkQueue)
                    }

                    return localQueue.poll()
                }
            }
        }
    }
}
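As a rough usage sketch (not part of this commit), the scheduler can be exercised directly through its Executor interface. The function name and the sleep-based wait below are assumptions; this MVP exposes no termination or await API beyond close().

// Illustrative sketch only: drive the scheduler as a plain Executor from the same package.
fun demoScheduler() {
    val scheduler = CoroutineScheduler(corePoolSize = 4)
    try {
        repeat(16) { i ->
            // execute() delegates to dispatch(command, intensive = false)
            scheduler.execute(Runnable { println("task $i on ${Thread.currentThread().name}") })
        }
        Thread.sleep(100) // crude wait: there is no shutdown-and-drain API in this MVP
    } finally {
        scheduler.close() // workers observe isClosed and exit their run() loop
    }
}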
Lines changed: 31 additions & 0 deletions (new file)

package kotlinx.coroutines.experimental.scheduling

import kotlinx.coroutines.experimental.*
import java.io.Closeable
import java.util.concurrent.TimeUnit
import kotlin.coroutines.experimental.AbstractCoroutineContextElement
import kotlin.coroutines.experimental.CoroutineContext

/**
 * Unstable API and subject to change.
 * Context marker which gives the scheduler a hint that submitted jobs can be distributed among cores aggressively.
 * Usually useful for massive job submission produced by a single coroutine, e.g. data-intensive fork-join tasks
 * or fan-out notifications to a large number of listeners.
 */
object ForkedMarker : AbstractCoroutineContextElement(ForkedKey)

private object ForkedKey : CoroutineContext.Key<ForkedMarker>

class ExperimentalCoroutineDispatcher(threads: Int = Runtime.getRuntime().availableProcessors()) : CoroutineDispatcher(), Delay, Closeable {

    private val coroutineScheduler = CoroutineScheduler(threads)

    override fun dispatch(context: CoroutineContext, block: Runnable) {
        coroutineScheduler.dispatch(block, context[ForkedKey] != null)
    }

    override fun scheduleResumeAfterDelay(time: Long, unit: TimeUnit, continuation: CancellableContinuation<Unit>) =
        DefaultExecutor.scheduleResumeAfterDelay(time, unit, continuation)

    override fun close() = coroutineScheduler.close()
}
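To illustrate the hint described in the ForkedMarker comment, a coroutine can be launched with the marker added to its context. This sketch is not part of the commit; it assumes the standard launch/runBlocking builders from kotlinx.coroutines.experimental, and the bodies of the two coroutines are placeholders.

import kotlinx.coroutines.experimental.launch
import kotlinx.coroutines.experimental.runBlocking
import kotlinx.coroutines.experimental.scheduling.ExperimentalCoroutineDispatcher
import kotlinx.coroutines.experimental.scheduling.ForkedMarker

fun main(args: Array<String>) = runBlocking {
    val dispatcher = ExperimentalCoroutineDispatcher()
    try {
        // Ordinary dispatch: tasks prefer the submitting worker's local queue.
        val plain = launch(dispatcher) { /* regular work */ }
        // With ForkedMarker the scheduler offloads to the global queue once the local queue
        // grows past FORKED_TASK_OFFLOAD_THRESHOLD (fan-out / fork-join style submission).
        val forked = launch(dispatcher + ForkedMarker) { /* mass submission of children */ }
        plain.join()
        forked.join()
    } finally {
        dispatcher.close()
    }
}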
Lines changed: 10 additions & 0 deletions (new file)

package kotlinx.coroutines.experimental.scheduling

import java.util.*

private val RANDOM = object : ThreadLocal<Random>() {
    override fun initialValue() = Random()
}

// Dynamic discovery is not yet supported
val RANDOM_PROVIDER = { RANDOM.get() }
Lines changed: 35 additions & 0 deletions (new file)

package kotlinx.coroutines.experimental.scheduling

import java.util.*

internal typealias Task = TimedTask
internal typealias GlobalQueue = Queue<Task>

internal val WORK_STEALING_TIME_RESOLUTION = readFromSystemProperties(
    "kotlinx.coroutines.scheduler.resolution.us", 500L, String::toLongOrNull)

internal val FORKED_TASK_OFFLOAD_THRESHOLD = readFromSystemProperties(
    "kotlinx.coroutines.scheduler.fork.threshold", 64L, String::toLongOrNull)

internal var schedulerTimeSource: TimeSource = NanoTimeSource

internal data class TimedTask(val submissionTime: Long, val task: Runnable)

internal abstract class TimeSource {
    abstract fun nanoTime(): Long
}

internal object NanoTimeSource : TimeSource() {
    override fun nanoTime() = System.nanoTime()
}

private fun <T> readFromSystemProperties(propertyName: String, defaultValue: T, parser: (String) -> T?): T {
    val value = try {
        System.getProperty(propertyName)
    } catch (e: SecurityException) {
        null
    } ?: return defaultValue

    val parsed = parser(value)
    return parsed ?: error("System property '$propertyName' has unrecognized value '$value'")
}
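The two constants above are read once from JVM system properties when the scheduling classes are loaded, so they can be tuned at startup (typically via -D flags). A minimal sketch with purely illustrative values; the helper function name is made up.

// Illustrative only: must run before any scheduler/dispatcher class is loaded,
// otherwise pass the same properties as -D JVM flags instead.
fun configureSchedulerProperties() {
    // Work-stealing time resolution; the property name suggests microseconds (default 500).
    System.setProperty("kotlinx.coroutines.scheduler.resolution.us", "250")
    // Local-queue size beyond which "forked" submissions go straight to the global queue (default 64).
    System.setProperty("kotlinx.coroutines.scheduler.fork.threshold", "32")
}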
Lines changed: 136 additions & 0 deletions (new file)

package kotlinx.coroutines.experimental.scheduling

import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.atomic.AtomicReferenceArray

internal const val BUFFER_CAPACITY_BASE = 7
internal const val BUFFER_CAPACITY = 1 shl BUFFER_CAPACITY_BASE
internal const val MASK = BUFFER_CAPACITY - 1 // 128 by default

/**
 * Unstable API and subject to change.
 * Tightly coupled with [CoroutineScheduler] queue of pending tasks, but extracted to a separate file for simplicity.
 * At any moment the queue is used only by [CoroutineScheduler.PoolWorker] threads; it has exactly one producer (the worker owning this queue)
 * and any number of consumers: other pool workers trying to steal work.
 *
 * Fairness
 * [WorkQueue] provides semi-FIFO order with priority for the most recently submitted task, assuming
 * that the current and submitted tasks are communicating and sharing state, which makes such communication extremely fast.
 * E.g. submitted jobs [1, 2, 3, 4] will be executed in [4, 1, 2, 3] order.
 *
 * Work offloading
 * When the queue is full, half of the existing tasks are offloaded to the global queue, which is regularly polled by other pool workers.
 * Offloading occurs in LIFO order for the sake of implementation simplicity: it should be extremely rare and occurs only in specific use-cases
 * (e.g. when a coroutine starts a heavy fork-join-like computation), so fairness is not important.
 * As an alternative, offloading directly to some [CoroutineScheduler.PoolWorker] could be used, but that would require a strategy for
 * selecting the most idle worker, and the implementation would have to be aware of multiple producers.
 */
internal class WorkQueue {

    internal val bufferSize: Int get() = producerIndex.get() - consumerIndex.get()
    private val buffer: AtomicReferenceArray<Task?> = AtomicReferenceArray(BUFFER_CAPACITY)
    private val lastScheduledTask: AtomicReference<Task?> = AtomicReference(null)

    private val producerIndex: AtomicInteger = AtomicInteger(0)
    private val consumerIndex: AtomicInteger = AtomicInteger(0)

    /**
     * Retrieves and removes a task from the head of the queue.
     * Invariant: this method is called only by the owner of the queue ([pollExternal] is not).
     */
    fun poll(): Task? {
        return lastScheduledTask.getAndSet(null) ?: pollExternal()
    }

    /**
     * Invariant: this method is called only by the owner of the queue.
     * @param task task to put into the local queue
     * @param globalQueue fallback queue which is used when the local queue is overflown
     */
    fun offer(task: Task, globalQueue: GlobalQueue) {
        while (true) {
            val previous = lastScheduledTask.get()
            if (lastScheduledTask.compareAndSet(previous, task)) {
                if (previous != null) {
                    addLast(previous, globalQueue)
                }
                return
            }
        }
    }

    /**
     * Offloads half of the current buffer to [sink].
     * @param byTimer whether the task deadline should be checked before offloading
     */
    inline fun offloadWork(byTimer: Boolean, sink: (Task) -> Unit) {
        repeat((bufferSize / 2).coerceAtLeast(1)) {
            if (bufferSize == 0) { // try to steal the head if the buffer is empty
                val lastScheduled = lastScheduledTask.get() ?: return
                if (!byTimer || schedulerTimeSource.nanoTime() - lastScheduled.submissionTime < WORK_STEALING_TIME_RESOLUTION) {
                    return
                }

                if (lastScheduledTask.compareAndSet(lastScheduled, null)) {
                    sink(lastScheduled)
                    return
                }
            }

            // TODO use batch drain and (if target queue allows) batch insert
            val task = pollExternal { !byTimer || schedulerTimeSource.nanoTime() - it.submissionTime >= WORK_STEALING_TIME_RESOLUTION }
                    ?: return
            sink(task)
        }
    }

    /**
     * [poll] for external (not owning this queue) workers
     */
    private inline fun pollExternal(predicate: (Task) -> Boolean = { true }): Task? {
        while (true) {
            val tailLocal = consumerIndex.get()
            if (tailLocal - producerIndex.get() == 0) return null
            val index = tailLocal and MASK
            val element = buffer[index] ?: continue
            if (!predicate(element)) {
                return null
            }

            if (consumerIndex.compareAndSet(tailLocal, tailLocal + 1)) {
                // 1) Help GC 2) Signal producer that this slot is consumed and may be used
                return buffer.getAndSet(index, null)
            }
        }
    }

    // Called only by owner
    private fun addLast(task: Task, globalQueue: GlobalQueue) {
        while (!tryAddLast(task)) {
            offloadWork(false) {
                globalQueue.add(it)
            }
        }
    }

    // Called only by owner
    private fun tryAddLast(task: Task): Boolean {
        if (bufferSize == BUFFER_CAPACITY - 1) return false
        val headLocal = producerIndex.get()
        val nextIndex = headLocal and MASK

        /*
         * If the current element is not null then we're racing with consumers for the tail. If we skipped this check,
         * a consumer could null out the current element and it would be lost. If we're racing for the tail then
         * the queue is close to overflow => it's fine to offload work to the global queue
         */
        if (buffer[nextIndex] != null) {
            return false
        }

        buffer.lazySet(nextIndex, task)
        producerIndex.incrementAndGet()
        return true
    }
}
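A small sketch (not from the commit) of the semi-FIFO ordering described in the class comment above. It assumes single-threaded access to the internal WorkQueue, TimedTask and GlobalQueue types from the same package; the function name is made up.

import java.util.concurrent.ConcurrentLinkedQueue

// Illustrative only: offer four tasks from the "owner" thread and drain them.
fun demoOrdering() {
    val globalQueue: GlobalQueue = ConcurrentLinkedQueue<Task>()
    val queue = WorkQueue()
    (1..4).forEach { i ->
        queue.offer(TimedTask(System.nanoTime(), Runnable { print("$i ") }), globalQueue)
    }
    // The most recently offered task runs first, then the rest in FIFO order: prints "4 1 2 3".
    generateSequence { queue.poll() }.forEach { it.task.run() }
}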
Lines changed: 53 additions & 0 deletions (new file)

package kotlinx.coroutines.experimental.scheduling

import kotlinx.coroutines.experimental.TestBase
import org.junit.After
import org.junit.Test
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CountDownLatch
import java.util.concurrent.atomic.AtomicInteger
import kotlin.coroutines.experimental.CoroutineContext
import kotlin.coroutines.experimental.EmptyCoroutineContext
import kotlin.test.assertEquals

class CoroutineSchedulerStressTest : TestBase() {

    private var dispatcher: ExperimentalCoroutineDispatcher = ExperimentalCoroutineDispatcher()
    private val observedThreads = ConcurrentHashMap<Thread, MutableSet<Int>>()
    private val tasksNum = 1_000_000
    private val processed = AtomicInteger(0)

    @After
    fun tearDown() {
        dispatcher.close()
    }

    @Test
    fun submitTasks() {
        stressTest(ForkedMarker)
    }

    @Test
    fun submitTasksForked() {
        stressTest(EmptyCoroutineContext)
    }

    private fun stressTest(ctx: CoroutineContext) {
        val finishLatch = CountDownLatch(1)

        for (i in 1..tasksNum) {
            dispatcher.dispatch(ctx, Runnable {
                val numbers = observedThreads.computeIfAbsent(Thread.currentThread(), { _ -> hashSetOf() })
                require(numbers.add(i))
                if (processed.incrementAndGet() == tasksNum) {
                    finishLatch.countDown()
                }
            })
        }

        finishLatch.await()
        assertEquals(Runtime.getRuntime().availableProcessors(), observedThreads.size)
        val result = observedThreads.values.flatMap { it }.toSet()
        assertEquals((1..tasksNum).toSet(), result)
    }
}
