Skip to content

Commit 498b16a

Browse files
committed
Add :on_cancel
1 parent c31cb55 commit 498b16a

File tree

3 files changed

+56
-8
lines changed

3 files changed

+56
-8
lines changed

lib/gen_stage.ex

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1694,6 +1694,12 @@ defmodule GenStage do
16941694
* `:stacktrace` - the stacktrace of the function that started the
16951695
stream.
16961696
1697+
* `:on_cancel` - what happens when all consumers cancel. The default
1698+
is to keep the stream running. Set it to `:stop` to stop the producer.
1699+
To avoid race conditions, it is recommend to only set this option if
1700+
`:demand` is set to `:accumulate` and forwarded only after all consumers
1701+
subscribe
1702+
16971703
All other options that would be given for `start_link/3` are
16981704
also accepted.
16991705
"""

lib/gen_stage/streamer.ex

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,33 +14,60 @@ defmodule GenStage.Streamer do
1414
x, {acc, counter} -> {:cont, {[x | acc], counter - 1}}
1515
end)
1616

17-
{:producer, {stack, continuation}, Keyword.take(opts, [:dispatcher, :demand])}
17+
on_cancel =
18+
case Keyword.get(opts, :on_cancel, :continue) do
19+
:continue -> nil
20+
:stop -> %{}
21+
end
22+
23+
{:producer, {stack, continuation, on_cancel}, Keyword.take(opts, [:dispatcher, :demand])}
24+
end
25+
26+
def handle_subscribe(:consumer, _opts, {pid, ref}, {stack, continuation, on_cancel}) do
27+
if on_cancel do
28+
{:automatic, {stack, continuation, Map.put(on_cancel, ref, pid)}}
29+
else
30+
{:automatic, {stack, continuation, on_cancel}}
31+
end
32+
end
33+
34+
def handle_cancel(_reason, {_, ref}, {stack, continuation, on_cancel}) do
35+
case on_cancel do
36+
%{^ref => _} when map_size(on_cancel) == 1 ->
37+
{:stop, :normal, {stack, continuation, Map.delete(on_cancel, ref)}}
38+
39+
%{^ref => _} ->
40+
{:noreply, [], {stack, continuation, Map.delete(on_cancel, ref)}}
41+
42+
_ ->
43+
{:noreply, [], {stack, continuation, on_cancel}}
44+
end
1845
end
1946

20-
def handle_demand(_demand, {stack, continuation}) when is_atom(continuation) do
21-
{:noreply, [], {stack, continuation}}
47+
def handle_demand(_demand, {stack, continuation, on_cancel}) when is_atom(continuation) do
48+
{:noreply, [], {stack, continuation, on_cancel}}
2249
end
2350

24-
def handle_demand(demand, {stack, continuation}) when demand > 0 do
51+
def handle_demand(demand, {stack, continuation, on_cancel}) when demand > 0 do
2552
case continuation.({:cont, {[], demand}}) do
2653
{:suspended, {list, 0}, continuation} ->
27-
{:noreply, :lists.reverse(list), {stack, continuation}}
54+
{:noreply, :lists.reverse(list), {stack, continuation, on_cancel}}
2855

2956
{status, {list, _}} ->
3057
GenStage.async_info(self(), :stop)
31-
{:noreply, :lists.reverse(list), {stack, status}}
58+
{:noreply, :lists.reverse(list), {stack, status, on_cancel}}
3259
end
3360
end
3461

3562
def handle_info(:stop, state) do
3663
{:stop, :normal, state}
3764
end
3865

39-
def handle_info(msg, {stack, continuation}) do
66+
def handle_info(msg, {stack, continuation, on_cancel}) do
4067
log =
4168
~c"** Undefined handle_info in ~tp~n** Unhandled message: ~tp~n** Stream started at:~n~ts"
4269

4370
:error_logger.warning_msg(log, [inspect(__MODULE__), msg, Exception.format_stacktrace(stack)])
44-
{:noreply, [], {stack, continuation}}
71+
{:noreply, [], {stack, continuation, on_cancel}}
4572
end
4673
end

test/gen_stage_test.exs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1905,6 +1905,21 @@ defmodule GenStageTest do
19051905
assert Process.info(producer, :registered_name) ==
19061906
{:registered_name, :gen_stage_from_enumerable}
19071907
end
1908+
1909+
test "accepts a :on_cancel option" do
1910+
{:ok, pid} = GenStage.from_enumerable(Stream.cycle([1, 2, 3]))
1911+
assert [pid] |> GenStage.stream() |> Enum.take(5) == [1, 2, 3, 1, 2]
1912+
assert Process.alive?(pid)
1913+
1914+
{:ok, pid} = GenStage.from_enumerable(Stream.cycle([1, 2, 3]), on_cancel: :continue)
1915+
assert [pid] |> GenStage.stream() |> Enum.take(5) == [1, 2, 3, 1, 2]
1916+
assert Process.alive?(pid)
1917+
1918+
{:ok, pid} = GenStage.from_enumerable(Stream.cycle([1, 2, 3]), on_cancel: :stop)
1919+
assert [pid] |> GenStage.stream() |> Enum.take(5) == [1, 2, 3, 1, 2]
1920+
ref = Process.monitor(pid)
1921+
assert_receive {:DOWN, ^ref, _, _, _}
1922+
end
19081923
end
19091924

19101925
describe "subscribe_to names" do

0 commit comments

Comments
 (0)