1- @file:Suppress(" EXPERIMENTAL_FEATURE_WARNING" )
1+ @file:Suppress(" EXPERIMENTAL_FEATURE_WARNING" , " UNCHECKED_CAST " )
22
33package com.github.h0tk3y.kotlinMonads
44
5+ import kotlinx.coroutines.experimental.Here
56import java.io.Serializable
67import java.util.*
78import kotlin.coroutines.Continuation
8- import kotlin.coroutines.CoroutineIntrinsics
9+ import kotlin.coroutines.CoroutineContext
10+ import kotlin.coroutines.RestrictsSuspension
11+ import kotlin.coroutines.intrinsics.SUSPENDED_MARKER
12+ import kotlin.coroutines.intrinsics.suspendCoroutineOrReturn
913import kotlin.coroutines.startCoroutine
1014
1115fun <M : Monad <M , * >, T > doWith (m : Monad <M , T >,
12- c : suspend DoController <M , T >.() -> Unit ): Monad <M , T > =
16+ c : suspend DoController <M , T >.(T ) -> Unit ): Monad <M , T > =
1317 m.bind { t -> doWith(this , t, c) }
1418
15-
1619fun <M : Monad <M , * >, T > doWith (aReturn : Return <M >,
1720 defaultValue : T ,
18- c : suspend DoController <M , T >.() -> Unit ): Monad <M , T > {
21+ c : suspend DoController <M , T >.(T ) -> Unit ): Monad <M , T > {
1922 val controller = DoController (aReturn, defaultValue)
20- c.startCoroutine(controller, object : Continuation <Unit > {
23+ val f: suspend DoController <M , T >.() -> Unit = { c(defaultValue) }
24+ f.startCoroutine(controller, object : Continuation <Unit > {
25+ override fun resumeWithException (exception : Throwable ) {}
2126 override fun resume (value : Unit ) {}
22- override fun resumeWithException ( exception : Throwable ) = throw exception
27+ override val context : CoroutineContext = Here
2328 })
2429 return controller.lastResult
2530}
2631
27- val labelField by lazy {
28- val jClass = Class .forName(" kotlin.jvm.internal.RestrictedCoroutineImpl " )
32+ private val labelField by lazy {
33+ val jClass = Class .forName(" kotlin.jvm.internal.CoroutineImpl " )
2934 return @lazy jClass.getDeclaredField(" label" ).apply { isAccessible = true }
3035}
3136
32- var <T > Continuation <T >.label
33- get() = labelField.get(this )
34- set(value) = labelField.set(this @label, value)
37+ private val innerContinuationField by lazy {
38+ val jClass = Class .forName(" kotlinx.coroutines.experimental.DispatchedContinuation" )
39+ return @lazy jClass.getDeclaredField(" continuation" ).apply { isAccessible = true }
40+ }
41+
42+ private var <T > Continuation <T >.label
43+ get() = labelField.get(innerContinuationField.get(this ))
44+ set(value) = labelField.set(innerContinuationField.get(this @label), value)
3545
3646private fun <T , R > backupLabel (c : Continuation <T >, block : Continuation <T >.() -> R ): R {
3747 val backupLabel = c.label
@@ -40,14 +50,15 @@ private fun <T, R> backupLabel(c: Continuation<T>, block: Continuation<T>.() ->
4050 return r
4151}
4252
53+ @RestrictsSuspension
4354class DoController <M : Monad <M , * >, T >(val returning : Return <M >,
4455 val value : T ) : Serializable, Return<M> by returning {
4556 var lastResult: Monad <M , T > = returning.returns(value)
4657 internal set
4758
4859 private val stackSignals = Stack <Boolean >().apply { push(false ) }
4960
50- suspend fun bind (m : Monad <M , T >): T = CoroutineIntrinsics . suspendCoroutineOrReturn { c ->
61+ suspend fun bind (m : Monad <M , T >): T = suspendCoroutineOrReturn { c ->
5162 stackSignals.pop()
5263 stackSignals.push(true )
5364 var anyCont = false
@@ -65,10 +76,10 @@ class DoController<M : Monad<M, *>, T>(val returning: Return<M>,
6576 }
6677 }
6778 lastResult = if (anyCont) o else m
68- CoroutineIntrinsics . SUSPENDED
79+ SUSPENDED_MARKER
6980 }
7081
71- suspend fun then (m : Monad <M , T >) = CoroutineIntrinsics . suspendCoroutineOrReturn<Unit > { c ->
82+ suspend fun then (m : Monad <M , T >) = suspendCoroutineOrReturn<Unit > { c ->
7283 stackSignals.pop()
7384 stackSignals.push(true )
7485 var anyCont = false
@@ -84,15 +95,6 @@ class DoController<M : Monad<M, *>, T>(val returning: Return<M>,
8495 }
8596 }
8697 lastResult = if (anyCont) o else m
87- CoroutineIntrinsics .SUSPENDED
88- }
89- }
90-
91- fun main (args : Array <String >) {
92- val m = doWith(monadListOf(0 )) {
93- val x = bind(monadListOf(1 , 2 , 3 ))
94- val y = bind(monadListOf(x, x))
95- then(monadListOf(y, y + 1 ))
98+ SUSPENDED_MARKER
9699 }
97- println (m)
98100}
0 commit comments