|
| 1 | +package fs2 |
| 2 | +package interop.cats |
| 3 | + |
| 4 | +import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} |
| 5 | + |
| 6 | +import fs2.internal.{Actor, LinkedMap} |
| 7 | +import fs2.util.{Async, Attempt, Effect, Free} |
| 8 | + |
| 9 | +import _root_.cats.effect.IO |
| 10 | + |
| 11 | +import scala.concurrent.ExecutionContext |
| 12 | + |
| 13 | +// mostly cribbed from fs2-scalaz:TaskAsyncInstances.scala |
| 14 | +trait IOAsyncInstances { |
| 15 | + import IOAsyncInstances._ |
| 16 | + |
| 17 | + protected class EffectIO extends Effect[IO] { |
| 18 | + def pure[A](a: A) = IO.pure(a) |
| 19 | + def flatMap[A,B](a: IO[A])(f: A => IO[B]): IO[B] = a flatMap f |
| 20 | + override def delay[A](a: => A) = IO(a) |
| 21 | + def suspend[A](fa: => IO[A]) = IO.suspend(fa) |
| 22 | + def fail[A](err: Throwable) = IO.raiseError(err) |
| 23 | + def attempt[A](t: IO[A]) = t.attempt |
| 24 | + def unsafeRunAsync[A](t: IO[A])(cb: Attempt[A] => Unit): Unit = t.unsafeRunAsync(cb) |
| 25 | + override def toString = "Effect[IO]" |
| 26 | + } |
| 27 | + |
| 28 | + implicit def asyncInstance(implicit ec: ExecutionContext): Async[IO] = new EffectIO with Async[IO] { |
| 29 | + def ref[A]: IO[Async.Ref[IO, A]] = CatsIO.ref[A](ec) |
| 30 | + override def toString = "Async[IO]" |
| 31 | + } |
| 32 | + |
| 33 | + /* |
| 34 | + * Implementation is taken from `fs2` library, with only minor changes. See: |
| 35 | + * |
| 36 | + * https://github.com/functional-streams-for-scala/fs2/blob/v0.9.0-M2/core/src/main/scala/fs2/util/IO.scala |
| 37 | + * |
| 38 | + * Copyright (c) 2013 Paul Chiusano, and respective contributors |
| 39 | + * |
| 40 | + * and is licensed MIT, see LICENSE file at: |
| 41 | + * |
| 42 | + * https://github.com/functional-streams-for-scala/fs2/blob/series/0.9/LICENSE |
| 43 | + */ |
| 44 | + private[fs2] object CatsIO { |
| 45 | + private type Callback[A] = Either[Throwable, A] => Unit |
| 46 | + |
| 47 | + private trait MsgId |
| 48 | + private trait Msg[A] |
| 49 | + private object Msg { |
| 50 | + case class Read[A](cb: Callback[(A, Long)], id: MsgId) extends Msg[A] |
| 51 | + case class Nevermind[A](id: MsgId, cb: Callback[Boolean]) extends Msg[A] |
| 52 | + case class Set[A](r: Either[Throwable, A]) extends Msg[A] |
| 53 | + case class TrySet[A](id: Long, r: Either[Throwable, A], |
| 54 | + cb: Callback[Boolean]) extends Msg[A] |
| 55 | + } |
| 56 | + |
| 57 | + def ref[A](implicit ec: ExecutionContext): IO[Ref[A]] = IO { |
| 58 | + implicit val S = Strategy.fromExecutionContext(ec) |
| 59 | + |
| 60 | + var result: Either[Throwable, A] = null |
| 61 | + // any waiting calls to `access` before first `set` |
| 62 | + var waiting: LinkedMap[MsgId, Callback[(A, Long)]] = LinkedMap.empty |
| 63 | + // id which increases with each `set` or successful `modify` |
| 64 | + var nonce: Long = 0 |
| 65 | + |
| 66 | + lazy val actor: Actor[Msg[A]] = Actor.actor[Msg[A]] { |
| 67 | + case Msg.Read(cb, idf) => |
| 68 | + if (result eq null) waiting = waiting.updated(idf, cb) |
| 69 | + else { val r = result; val id = nonce; ec { cb(r.right.map((_,id))) }; () } |
| 70 | + |
| 71 | + case Msg.Set(r) => |
| 72 | + nonce += 1L |
| 73 | + if (result eq null) { |
| 74 | + val id = nonce |
| 75 | + waiting.values.foreach(cb => ec { cb(r.right.map((_,id))) }) |
| 76 | + waiting = LinkedMap.empty |
| 77 | + } |
| 78 | + result = r |
| 79 | + |
| 80 | + case Msg.TrySet(id, r, cb) => |
| 81 | + if (id == nonce) { |
| 82 | + nonce += 1L; val id2 = nonce |
| 83 | + waiting.values.foreach(cb => ec { cb(r.right.map((_,id2))) }) |
| 84 | + waiting = LinkedMap.empty |
| 85 | + result = r |
| 86 | + cb(Right(true)) |
| 87 | + } |
| 88 | + else cb(Right(false)) |
| 89 | + |
| 90 | + case Msg.Nevermind(id, cb) => |
| 91 | + val interrupted = waiting.get(id).isDefined |
| 92 | + waiting = waiting - id |
| 93 | + val _ = ec { cb (Right(interrupted)) } |
| 94 | + } |
| 95 | + |
| 96 | + new Ref(actor) |
| 97 | + } |
| 98 | + |
| 99 | + class Ref[A] private[fs2](actor: Actor[Msg[A]])(implicit ec: ExecutionContext, protected val F: Async[IO]) extends Async.Ref[IO,A] { |
| 100 | + |
| 101 | + def access: IO[(A, Either[Throwable,A] => IO[Boolean])] = |
| 102 | + IO(new MsgId {}).flatMap { mid => |
| 103 | + getStamped(mid).map { case (a, id) => |
| 104 | + val set = (a: Either[Throwable,A]) => |
| 105 | + IO.async[Boolean] { cb => actor ! Msg.TrySet(id, a, cb) } |
| 106 | + (a, set) |
| 107 | + } |
| 108 | + } |
| 109 | + |
| 110 | + /** |
| 111 | + * Return a `IO` that submits `t` to this ref for evaluation. |
| 112 | + * When it completes it overwrites any previously `put` value. |
| 113 | + */ |
| 114 | + def set(t: IO[A]): IO[Unit] = |
| 115 | + IO { ec { t.unsafeRunAsync { r => actor ! Msg.Set(r) } }; () } |
| 116 | + def setFree(t: Free[IO,A]): IO[Unit] = |
| 117 | + set(t.run(F)) |
| 118 | + def runSet(e: Either[Throwable,A]): Unit = |
| 119 | + actor ! Msg.Set(e) |
| 120 | + |
| 121 | + private def getStamped(msg: MsgId): IO[(A,Long)] = |
| 122 | + IO.async[(A,Long)] { cb => actor ! Msg.Read(cb, msg) } |
| 123 | + |
| 124 | + /** Return the most recently completed `set`, or block until a `set` value is available. */ |
| 125 | + override def get: IO[A] = IO(new MsgId {}).flatMap { mid => getStamped(mid).map(_._1) } |
| 126 | + |
| 127 | + /** Like `get`, but returns a `IO[Unit]` that can be used cancel the subscription. */ |
| 128 | + def cancellableGet: IO[(IO[A], IO[Unit])] = IO { |
| 129 | + val id = new MsgId {} |
| 130 | + val get = getStamped(id).map(_._1) |
| 131 | + val cancel = IO.async[Unit] { |
| 132 | + cb => actor ! Msg.Nevermind(id, r => cb(r.right.map(_ => ()))) |
| 133 | + } |
| 134 | + (get, cancel) |
| 135 | + } |
| 136 | + |
| 137 | + /** |
| 138 | + * Runs `t1` and `t2` simultaneously, but only the winner gets to |
| 139 | + * `set` to this `ref`. The loser continues running but its reference |
| 140 | + * to this ref is severed, allowing this ref to be garbage collected |
| 141 | + * if it is no longer referenced by anyone other than the loser. |
| 142 | + */ |
| 143 | + def setRace(t1: IO[A], t2: IO[A]): IO[Unit] = IO { |
| 144 | + val ref = new AtomicReference(actor) |
| 145 | + val won = new AtomicBoolean(false) |
| 146 | + val win = (res: Either[Throwable, A]) => { |
| 147 | + // important for GC: we don't reference this ref |
| 148 | + // or the actor directly, and the winner destroys any |
| 149 | + // references behind it! |
| 150 | + if (won.compareAndSet(false, true)) { |
| 151 | + val actor = ref.get |
| 152 | + ref.set(null) |
| 153 | + actor ! Msg.Set(res) |
| 154 | + } |
| 155 | + } |
| 156 | + t1.shift.unsafeRunAsync(win) |
| 157 | + t2.shift.unsafeRunAsync(win) |
| 158 | + } |
| 159 | + } |
| 160 | + |
| 161 | + } |
| 162 | +} |
| 163 | + |
| 164 | +private[fs2] object IOAsyncInstances { |
| 165 | + private implicit final class ECSyntax(val ec: ExecutionContext) extends AnyVal { |
| 166 | + def apply[A](thunk: => A): Unit = |
| 167 | + ec.execute(new Runnable { def run() = { thunk; () } }) |
| 168 | + } |
| 169 | +} |
0 commit comments