From 11c0c6e1d5ed3fb569a4c7c73b9be2ed3073c75c Mon Sep 17 00:00:00 2001 From: Lewin Bormann Date: Wed, 12 Jul 2023 22:42:52 +0200 Subject: [PATCH 1/5] Implement simple atomic stream select for #577 --- lib_eio/stream.ml | 46 +++++++++++++++++++++++++++++++++++++++++++++ lib_eio/stream.mli | 4 ++++ lib_eio/waiters.ml | 19 +++++++++++++++---- lib_eio/waiters.mli | 13 +++++++++++-- tests/stream.md | 19 +++++++++++++++++++ 5 files changed, 95 insertions(+), 6 deletions(-) diff --git a/lib_eio/stream.ml b/lib_eio/stream.ml index 974cfa3b7..8a11efc5f 100644 --- a/lib_eio/stream.ml +++ b/lib_eio/stream.ml @@ -94,6 +94,45 @@ module Locking = struct Mutex.unlock t.mutex; Some v + let select_of_many streams_fns = + let finished = Atomic.make false in + let cancel_fns = ref [] in + let add_cancel_fn fn = cancel_fns := fn :: !cancel_fns in + let cancel_all () = List.iter (fun fn -> fn ()) !cancel_fns in + let wait ctx enqueue (t, f) = begin + Mutex.lock t.mutex; + (* First check if any items are already available and return early if there are. *) + if not (Queue.is_empty t.items) + then ( + cancel_all (); + Mutex.unlock t.mutex; + enqueue (Ok (f (Queue.take t.items)))) + else add_cancel_fn @@ + (* Otherwise, register interest in this stream. *) + Waiters.cancellable_await_internal ~mutex:(Some t.mutex) t.readers t.id ctx (fun r -> + if Result.is_ok r then ( + if not (Atomic.compare_and_set finished false true) then ( + (* Another stream has yielded an item in the meantime. However, as + we have been waiting on this stream it must have been empty. + + As the stream's mutex was held since before last checking for an item, + the queue must be empty. + *) + assert ((Queue.length t.items) < t.capacity); + Queue.add (Result.get_ok r) t.items + ) else ( + (* remove all other entries of this fiber in other streams' waiters. *) + cancel_all () + )); + (* item is returned to waiting caller through enqueue and enter_unchecked. *) + enqueue (Result.map f r)) + end in + (* Register interest in all streams and return first available item. *) + let wait_for_stream streams_fns = begin + Suspend.enter_unchecked (fun ctx enqueue -> List.iter (wait ctx enqueue) streams_fns) + end in + wait_for_stream streams_fns + let length t = Mutex.lock t.mutex; let len = Queue.length t.items in @@ -125,6 +164,13 @@ let take_nonblocking = function | Sync x -> Sync.take_nonblocking x | Locking x -> Locking.take_nonblocking x +let select streams = + let filter s = match s with + | (Sync _, _) -> assert false + | (Locking x, f) -> (x, f) + in + Locking.select_of_many (List.map filter streams) + let length = function | Sync _ -> 0 | Locking x -> Locking.length x diff --git a/lib_eio/stream.mli b/lib_eio/stream.mli index 6554cac1a..79b7075b6 100644 --- a/lib_eio/stream.mli +++ b/lib_eio/stream.mli @@ -40,6 +40,10 @@ val take_nonblocking : 'a t -> 'a option Note that if another domain may add to the stream then a [None] result may already be out-of-date by the time this returns. *) +val select : ('a t * ('a -> 'b)) list -> 'b +(** [select] returns the first item yielded by any stream. This only + works for streams with non-zero capacity. *) + val length : 'a t -> int (** [length t] returns the number of items currently in [t]. *) diff --git a/lib_eio/waiters.ml b/lib_eio/waiters.ml index c0cbd4624..99c21155e 100644 --- a/lib_eio/waiters.ml +++ b/lib_eio/waiters.ml @@ -38,11 +38,12 @@ let rec wake_one t v = let is_empty = Lwt_dllist.is_empty -let await_internal ~mutex (t:'a t) id ctx enqueue = +let cancellable_await_internal ~mutex (t:'a t) id ctx enqueue = match Fiber_context.get_error ctx with | Some ex -> Option.iter Mutex.unlock mutex; - enqueue (Error ex) + enqueue (Error ex); + fun () -> () | None -> let resolved_waiter = ref Hook.null in let finished = Atomic.make false in @@ -56,14 +57,24 @@ let await_internal ~mutex (t:'a t) id ctx enqueue = enqueue (Error ex) ) in + let unwait () = + if Atomic.compare_and_set finished false true + then Hook.remove !resolved_waiter + in Fiber_context.set_cancel_fn ctx cancel; let waiter = { enqueue; finished } in match mutex with | None -> - resolved_waiter := add_waiter t waiter + resolved_waiter := add_waiter t waiter; + unwait | Some mutex -> resolved_waiter := add_waiter_protected ~mutex t waiter; - Mutex.unlock mutex + Mutex.unlock mutex; + unwait + +let await_internal ~mutex (t: 'a t) id ctx enqueue = + let _cancel = (cancellable_await_internal ~mutex t id ctx enqueue) in + () (* Returns a result if the wait succeeds, or raises if cancelled. *) let await ~mutex waiters id = diff --git a/lib_eio/waiters.mli b/lib_eio/waiters.mli index 724cf96e7..04b8d4557 100644 --- a/lib_eio/waiters.mli +++ b/lib_eio/waiters.mli @@ -27,8 +27,8 @@ val await : If [t] can be used from multiple domains: - [mutex] must be set to the mutex to use to unlock it. - [mutex] must be already held when calling this function, which will unlock it before blocking. - When [await] returns, [mutex] will have been unlocked. - @raise Cancel.Cancelled if the fiber's context is cancelled *) + When [await] returns, [mutex] will have been unlocked. + @raise Cancel.Cancelled if the fiber's context is cancelled *) val await_internal : mutex:Mutex.t option -> @@ -40,3 +40,12 @@ val await_internal : Note: [enqueue] is called from the triggering domain, which is currently calling {!wake_one} or {!wake_all} and must therefore be holding [mutex]. *) + +val cancellable_await_internal : + mutex:Mutex.t option -> + 'a t -> Ctf.id -> Fiber_context.t -> + (('a, exn) result -> unit) -> (unit -> unit) +(** Like [await_internal], but returns a function which, when called, + removes the current fiber continuation from the waiters list. + This is used when a fiber is waiting for multiple [Waiter]s simultaneously, + and needs to remove itself from other waiters once it has been enqueued by one.*) diff --git a/tests/stream.md b/tests/stream.md index c5a035e3b..10771d00d 100644 --- a/tests/stream.md +++ b/tests/stream.md @@ -357,3 +357,22 @@ Non-blocking take with zero-capacity stream: +Got None from stream - : unit = () ``` + +Selecting from multiple channels: + +```ocaml +# run @@ fun () -> Switch.run (fun sw -> + let t1, t2 = (S.create 2), (S.create 2) in + let selector = [(t1, fun x -> x); (t2, fun x -> x)] in + Fiber.fork ~sw (fun () -> S.add t2 "foo"); + Fiber.fork ~sw (fun () -> traceln "%s" (S.select selector)); + Fiber.fork ~sw (fun () -> traceln "%s" (S.select selector)); + Fiber.fork ~sw (fun () -> traceln "%s" (S.select selector)); + Fiber.fork ~sw (fun () -> S.add t2 "bar"); + Fiber.fork ~sw (fun () -> S.add t1 "baz"); + ) ++foo ++bar ++baz +- : unit = () +``` From d87a3f5e23a7090b934da039361644c2296d51de Mon Sep 17 00:00:00 2001 From: Lewin Bormann Date: Thu, 20 Jul 2023 08:04:50 +0200 Subject: [PATCH 2/5] Stream.select: Fix race condition in waiter set-up. --- lib_eio/stream.ml | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/lib_eio/stream.ml b/lib_eio/stream.ml index 8a11efc5f..17bf6b943 100644 --- a/lib_eio/stream.ml +++ b/lib_eio/stream.ml @@ -104,9 +104,15 @@ module Locking = struct (* First check if any items are already available and return early if there are. *) if not (Queue.is_empty t.items) then ( - cancel_all (); - Mutex.unlock t.mutex; - enqueue (Ok (f (Queue.take t.items)))) + (* If no other stream has yielded already, we are the first one. *) + if Atomic.compare_and_set finished false true + then ( + (* Therefore, cancel all other waiters and take available item. *) + cancel_all (); + let item = Queue.take t.items in + enqueue (Ok (f item))); + Mutex.unlock t.mutex + ) else add_cancel_fn @@ (* Otherwise, register interest in this stream. *) Waiters.cancellable_await_internal ~mutex:(Some t.mutex) t.readers t.id ctx (fun r -> From d0baaab24bb8c3efbaa4ee322dd48c7f03908b13 Mon Sep 17 00:00:00 2001 From: Lewin Bormann Date: Sat, 22 Jul 2023 12:23:37 +0200 Subject: [PATCH 3/5] Wake writers after taking item in select_of_many --- lib_eio/stream.ml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib_eio/stream.ml b/lib_eio/stream.ml index 17bf6b943..3f6f7b02e 100644 --- a/lib_eio/stream.ml +++ b/lib_eio/stream.ml @@ -110,6 +110,7 @@ module Locking = struct (* Therefore, cancel all other waiters and take available item. *) cancel_all (); let item = Queue.take t.items in + ignore (Waiters.wake_one t.writers ()); enqueue (Ok (f item))); Mutex.unlock t.mutex ) @@ -128,6 +129,7 @@ module Locking = struct Queue.add (Result.get_ok r) t.items ) else ( (* remove all other entries of this fiber in other streams' waiters. *) + ignore (Waiters.wake_one t.writers ()); cancel_all () )); (* item is returned to waiting caller through enqueue and enter_unchecked. *) From b8b9b76e8f90838cec36ff8de0f54e59ccc18c1b Mon Sep 17 00:00:00 2001 From: Lewin Bormann Date: Sat, 22 Jul 2023 12:29:05 +0200 Subject: [PATCH 4/5] Only enqueue fiber if stream is first to be selected Otherwise, `Stdlib.Effect.Continuation_already_resumed` will be raised. --- lib_eio/stream.ml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib_eio/stream.ml b/lib_eio/stream.ml index 3f6f7b02e..8ebbcb8de 100644 --- a/lib_eio/stream.ml +++ b/lib_eio/stream.ml @@ -130,10 +130,10 @@ module Locking = struct ) else ( (* remove all other entries of this fiber in other streams' waiters. *) ignore (Waiters.wake_one t.writers ()); - cancel_all () - )); - (* item is returned to waiting caller through enqueue and enter_unchecked. *) - enqueue (Result.map f r)) + cancel_all (); + (* item is returned to waiting caller through enqueue and enter_unchecked. *) + enqueue (Result.map f r)) + )); end in (* Register interest in all streams and return first available item. *) let wait_for_stream streams_fns = begin From f733045f896fb3e8a194c9dc14d4915224cc9bfc Mon Sep 17 00:00:00 2001 From: Lewin Bormann Date: Sat, 22 Jul 2023 12:31:05 +0200 Subject: [PATCH 5/5] Add small benchmark for Stream.select --- bench/bench_select.ml | 57 +++++++++++++++++++++++++++++++++++++++++++ bench/main.ml | 1 + 2 files changed, 58 insertions(+) create mode 100644 bench/bench_select.ml diff --git a/bench/bench_select.ml b/bench/bench_select.ml new file mode 100644 index 000000000..ef91398ca --- /dev/null +++ b/bench/bench_select.ml @@ -0,0 +1,57 @@ + +open Eio.Stdenv +open Eio + +let sender_fibers = 4 +let cap = 10 + +let message = 1234 + +(* Send [n_msgs] items to streams in a round-robin way. *) +let sender ~n_msgs streams = + let msgs = Seq.take n_msgs (Seq.ints 0) in + let streams = Seq.cycle (List.to_seq streams) in + let zipped = Seq.zip msgs streams in + ignore (Seq.iter (fun (_i, stream) -> + Stream.add stream message) zipped) + +(* Start one sender fiber for each stream, and let it send n_msgs messages. + Each fiber sends to all streams in a round-robin way. *) +let run_senders ~dom_mgr ?(n_msgs = 100) streams = + Switch.run @@ fun sw -> + ignore @@ List.iter (fun _stream -> + Fiber.fork ~sw (fun () -> + Domain_manager.run dom_mgr (fun () -> + sender ~n_msgs streams))) streams + +(* Receive messages from all streams. *) +let receiver ~n_msgs streams = + for _i = 1 to n_msgs do + assert (Int.equal message (Stream.select streams)); + done + +(* Create [n] streams. *) +let make_streams cap n = + let unfolder i = if i == 0 then None else Some (Stream.create cap, i-1) in + let seq = Seq.unfold unfolder n in + List.of_seq seq + +let run env = + let dom_mgr = domain_mgr env in + let clock = clock env in + let streams = make_streams cap sender_fibers in + let selector = List.map (fun s -> (s, fun i -> i)) streams in + let n_msgs = 10000 in + Switch.run @@ fun sw -> + Fiber.fork ~sw (fun () -> run_senders ~dom_mgr ~n_msgs streams); + let before = Time.now clock in + receiver ~n_msgs:(sender_fibers * n_msgs) selector; + let after = Time.now clock in + let elapsed = after -. before in + let time_per_iter = elapsed /. (Float.of_int @@ sender_fibers * n_msgs) in + [Metric.create + (Printf.sprintf "sync:true senders:%d msgs_per_sender:%d" sender_fibers n_msgs) + (`Float (1e9 *. time_per_iter)) "ns" + "Time per transmitted int"] + + diff --git a/bench/main.ml b/bench/main.ml index 707253019..4d4b0bd07 100644 --- a/bench/main.ml +++ b/bench/main.ml @@ -9,6 +9,7 @@ let benchmarks = [ "Stream", Bench_stream.run; "HTTP", Bench_http.run; "Eio_unix.Fd", Bench_fd.run; + "StreamSelect", Bench_select.run; ] let usage_error () =