Skip to content

Commit d9cfe7c

Browse files
authored
Avoid sending events in accumulate demand (#296)
1 parent f87ccd5 commit d9cfe7c

File tree

2 files changed

+88
-15
lines changed

2 files changed

+88
-15
lines changed

lib/gen_stage.ex

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1339,9 +1339,9 @@ defmodule GenStage do
13391339
Sets the demand mode for a producer.
13401340
13411341
When `:forward`, the demand is always forwarded to the `c:handle_demand/2`
1342-
callback. When `:accumulate`, demand is accumulated until its mode is
1343-
set to `:forward`. This is useful as a synchronization mechanism, where
1344-
the demand is accumulated until all consumers are subscribed. Defaults
1342+
callback. When `:accumulate`, both demand and events are accumulated until
1343+
its mode is set to `:forward`. This is useful as a synchronization mechanism,
1344+
where the demand is accumulated until all consumers are subscribed. Defaults
13451345
to `:forward`.
13461346
13471347
This command is asynchronous.
@@ -2261,11 +2261,11 @@ defmodule GenStage do
22612261

22622262
if is_list(events) do
22632263
fold_fun = fn
2264-
d, {:noreply, %{state: state} = stage} ->
2265-
noreply_callback(:handle_demand, [d, state], stage)
2264+
event, {:noreply, stage} ->
2265+
handle_accumulated_event(event, stage)
22662266

2267-
d, {:noreply, %{state: state} = stage, _} ->
2268-
noreply_callback(:handle_demand, [d, state], stage)
2267+
event, {:noreply, stage, _} ->
2268+
handle_accumulated_event(event, stage)
22692269

22702270
_, {:stop, _, _} = acc ->
22712271
acc
@@ -2285,6 +2285,14 @@ defmodule GenStage do
22852285
end
22862286
end
22872287

2288+
defp handle_accumulated_event({:demand, d}, stage) do
2289+
take_from_buffer_or_handle_demand(d, stage)
2290+
end
2291+
2292+
defp handle_accumulated_event({:dispatch, events, length}, stage) do
2293+
{:noreply, dispatch_events(events, length, stage)}
2294+
end
2295+
22882296
defp producer_subscribe(opts, from, stage) do
22892297
%{mod: mod, state: state, dispatcher_mod: dispatcher_mod, dispatcher_state: dispatcher_state} =
22902298
stage
@@ -2370,23 +2378,33 @@ defmodule GenStage do
23702378
take_pc_events(queue, counter, stage)
23712379

23722380
%{} ->
2373-
case take_from_buffer(counter, %{stage | dispatcher_state: dispatcher_state}) do
2374-
{:ok, 0, stage} ->
2375-
{:noreply, stage}
2381+
take_from_buffer_or_handle_demand(counter, %{stage | dispatcher_state: dispatcher_state})
2382+
end
2383+
end
23762384

2377-
{:ok, counter, %{events: :forward, state: state} = stage} ->
2378-
noreply_callback(:handle_demand, [counter, state], stage)
2385+
defp take_from_buffer_or_handle_demand(counter, stage) do
2386+
case take_from_buffer(counter, stage) do
2387+
{:ok, 0, stage} ->
2388+
{:noreply, stage}
23792389

2380-
{:ok, counter, %{events: events} = stage} when is_list(events) ->
2381-
{:noreply, %{stage | events: [counter | events]}}
2382-
end
2390+
{:ok, counter, %{events: :forward, state: state} = stage} ->
2391+
noreply_callback(:handle_demand, [counter, state], stage)
2392+
2393+
{:ok, counter, %{events: events} = stage} when is_list(events) ->
2394+
{:noreply, %{stage | events: [{:demand, counter} | events]}}
23832395
end
23842396
end
23852397

23862398
defp dispatch_events([], _length, stage) do
23872399
stage
23882400
end
23892401

2402+
# We don't dispatch when we are accumulating demand
2403+
defp dispatch_events(to_dispatch, length, %{events: events, type: :producer} = stage)
2404+
when is_list(events) do
2405+
%{stage | events: [{:dispatch, to_dispatch, length} | events]}
2406+
end
2407+
23902408
defp dispatch_events(events, _length, %{type: :consumer} = stage) do
23912409
error_msg =
23922410
~c"GenStage consumer ~tp cannot dispatch events (an empty list must be returned): ~tp~n"

test/gen_stage_test.exs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,29 @@ defmodule GenStageTest do
33

44
import ExUnit.CaptureLog
55

6+
defmodule EventProducer do
7+
@moduledoc """
8+
Produce events when receives a cast
9+
"""
10+
use GenStage
11+
12+
def start_link(init) do
13+
GenStage.start_link(__MODULE__, init)
14+
end
15+
16+
@impl GenStage
17+
@doc false
18+
def init(init), do: init
19+
20+
@impl GenStage
21+
@doc false
22+
def handle_cast(event, state), do: {:noreply, [event], state}
23+
24+
@impl GenStage
25+
@doc false
26+
def handle_demand(_demand, state), do: {:noreply, [], state}
27+
end
28+
629
defmodule Counter do
730
@moduledoc """
831
A producer that works as a counter in batches.
@@ -673,6 +696,38 @@ defmodule GenStageTest do
673696
assert_receive {:consumed, [0, 1, 2, 3]}
674697
end
675698

699+
test "can be set to :accumulate via API using broadcast" do
700+
{:ok, producer} =
701+
EventProducer.start_link({:producer, [], dispatcher: GenStage.BroadcastDispatcher})
702+
703+
assert GenStage.demand(producer) == :forward
704+
{:ok, consumer1} = Forwarder.start_link({:consumer, self(), subscribe_to: [producer]})
705+
{:ok, consumer2} = Forwarder.start_link({:consumer, self(), subscribe_to: [producer]})
706+
GenStage.demand(producer, :accumulate)
707+
assert GenStage.demand(producer) == :accumulate
708+
709+
GenStage.stop(consumer1)
710+
GenStage.stop(consumer2)
711+
712+
assert :ok = GenStage.cast(producer, 1)
713+
assert :ok = GenStage.cast(producer, 2)
714+
assert :ok = GenStage.cast(producer, 3)
715+
assert :ok = GenStage.cast(producer, 4)
716+
717+
Process.sleep(200)
718+
{:ok, _consumer1} = Forwarder.start_link({:consumer, self(), subscribe_to: [producer]})
719+
{:ok, _consumer2} = Forwarder.start_link({:consumer, self(), subscribe_to: [producer]})
720+
refute_receive {:consumed, _}
721+
722+
GenStage.demand(producer, :forward)
723+
assert GenStage.demand(producer) == :forward
724+
assert {{[], []}, 0, _} = :sys.get_state(producer).buffer
725+
assert_receive {:consumed, [1]}
726+
assert_receive {:consumed, [2]}
727+
assert_receive {:consumed, [3]}
728+
assert_receive {:consumed, [4]}
729+
end
730+
676731
test "can be set to :accumulate via API" do
677732
{:ok, producer} = Counter.start_link({:producer, 0})
678733
assert GenStage.demand(producer) == :forward

0 commit comments

Comments
 (0)