Skip to content

Commit 8f1e4ca

Browse files
committed
Adds full support for handle_continue/2 to gen_stage
* {:continue, _term} instructions can now be returned as one would expect from gen_server. * :hibernate is now supported on init similar to gen_server.
1 parent f4da24f commit 8f1e4ca

File tree

2 files changed

+372
-22
lines changed

2 files changed

+372
-22
lines changed

lib/gen_stage.ex

Lines changed: 169 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -885,11 +885,18 @@ defmodule GenStage do
885885

886886
@callback init(args :: term) ::
887887
{:producer, state}
888+
| {:producer, state, {:continue, term} | :hibernate}
888889
| {:producer, state, [producer_option]}
890+
| {:producer, state, {:continue, term} | :hibernate, [producer_option]}
889891
| {:producer_consumer, state}
892+
| {:producer_consumer, state, {:continue, term} | :hibernate}
890893
| {:producer_consumer, state, [producer_consumer_option]}
894+
| {:producer_consumer, state, {:continue, term} | :hibernate,
895+
[producer_consumer_option]}
891896
| {:consumer, state}
897+
| {:consumer, state, {:continue, term} | :hibernate}
892898
| {:consumer, state, [consumer_option]}
899+
| {:consumer, state, {:continue, term} | :hibernate, [consumer_option]}
893900
| :ignore
894901
| {:stop, reason :: any}
895902
when state: any
@@ -925,6 +932,7 @@ defmodule GenStage do
925932
@callback handle_demand(demand :: pos_integer, state :: term) ::
926933
{:noreply, [event], new_state}
927934
| {:noreply, [event], new_state, :hibernate}
935+
| {:noreply, [event], new_state, {:continue, term}}
928936
| {:stop, reason, new_state}
929937
when new_state: term, reason: term, event: term
930938

@@ -1004,6 +1012,7 @@ defmodule GenStage do
10041012
) ::
10051013
{:noreply, [event], new_state}
10061014
| {:noreply, [event], new_state, :hibernate}
1015+
| {:noreply, [event], new_state, {:continue, term}}
10071016
| {:stop, reason, new_state}
10081017
when event: term, new_state: term, reason: term
10091018

@@ -1017,6 +1026,7 @@ defmodule GenStage do
10171026
@callback handle_events(events :: [event], from, state :: term) ::
10181027
{:noreply, [event], new_state}
10191028
| {:noreply, [event], new_state, :hibernate}
1029+
| {:noreply, [event], new_state, {:continue, term}}
10201030
| {:stop, reason, new_state}
10211031
when new_state: term, reason: term, event: term
10221032

@@ -1056,8 +1066,10 @@ defmodule GenStage do
10561066
@callback handle_call(request :: term, from :: GenServer.from(), state :: term) ::
10571067
{:reply, reply, [event], new_state}
10581068
| {:reply, reply, [event], new_state, :hibernate}
1069+
| {:reply, reply, [event], new_state, {:continue, term}}
10591070
| {:noreply, [event], new_state}
10601071
| {:noreply, [event], new_state, :hibernate}
1072+
| {:noreply, [event], new_state, {:continue, term}}
10611073
| {:stop, reason, reply, new_state}
10621074
| {:stop, reason, new_state}
10631075
when reply: term, new_state: term, reason: term, event: term
@@ -1086,6 +1098,7 @@ defmodule GenStage do
10861098
@callback handle_cast(request :: term, state :: term) ::
10871099
{:noreply, [event], new_state}
10881100
| {:noreply, [event], new_state, :hibernate}
1101+
| {:noreply, [event], new_state, {:continue, term}}
10891102
| {:stop, reason :: term, new_state}
10901103
when new_state: term, event: term
10911104

@@ -1103,6 +1116,27 @@ defmodule GenStage do
11031116
@callback handle_info(message :: term, state :: term) ::
11041117
{:noreply, [event], new_state}
11051118
| {:noreply, [event], new_state, :hibernate}
1119+
| {:noreply, [event], new_state, {:continue, term}}
1120+
| {:stop, reason :: term, new_state}
1121+
when new_state: term, event: term
1122+
1123+
@doc """
1124+
Invoked to handle `continue` instructions.
1125+
1126+
It is useful for performing work after initialization or for splitting the work
1127+
in a callback in multiple steps, updating the process state along the way.
1128+
1129+
Return values are the same as `c:handle_cast/2`.
1130+
1131+
This callback is optional. If one is not implemented, the server will fail
1132+
if a continue instruction is used.
1133+
1134+
This callback is only supported on Erlang/OTP 21+.
1135+
"""
1136+
@callback handle_continue(continue :: term, state :: term) ::
1137+
{:noreply, [event], new_state}
1138+
| {:noreply, [event], new_state, :hibernate}
1139+
| {:noreply, [event], new_state, {:continue, term}}
11061140
| {:stop, reason :: term, new_state}
11071141
when new_state: term, event: term
11081142

@@ -1139,6 +1173,7 @@ defmodule GenStage do
11391173
format_status: 2,
11401174
handle_call: 3,
11411175
handle_cast: 2,
1176+
handle_continue: 2,
11421177
handle_info: 2,
11431178
terminate: 2
11441179
]
@@ -1722,22 +1757,58 @@ defmodule GenStage do
17221757
def init({mod, args}) do
17231758
case mod.init(args) do
17241759
{:producer, state} ->
1725-
init_producer(mod, [], state)
1760+
init_producer(mod, [], state, nil)
1761+
1762+
{:producer, state, {:continue, _term} = continue} ->
1763+
init_producer(mod, [], state, continue)
1764+
1765+
{:producer, state, :hibernate} ->
1766+
init_producer(mod, [], state, :hibernate)
17261767

17271768
{:producer, state, opts} when is_list(opts) ->
1728-
init_producer(mod, opts, state)
1769+
init_producer(mod, opts, state, nil)
1770+
1771+
{:producer, state, {:continue, _term} = continue, opts} when is_list(opts) ->
1772+
init_producer(mod, opts, state, continue)
1773+
1774+
{:producer, state, :hibernate, opts} when is_list(opts) ->
1775+
init_producer(mod, opts, state, :hibernate)
17291776

17301777
{:producer_consumer, state} ->
1731-
init_producer_consumer(mod, [], state)
1778+
init_producer_consumer(mod, [], state, nil)
1779+
1780+
{:producer_consumer, state, {:continue, _term} = continue} ->
1781+
init_producer_consumer(mod, [], state, continue)
1782+
1783+
{:producer_consumer, state, :hibernate} ->
1784+
init_producer_consumer(mod, [], state, :hibernate)
17321785

17331786
{:producer_consumer, state, opts} when is_list(opts) ->
1734-
init_producer_consumer(mod, opts, state)
1787+
init_producer_consumer(mod, opts, state, nil)
1788+
1789+
{:producer_consumer, state, {:continue, _term} = continue, opts} when is_list(opts) ->
1790+
init_producer_consumer(mod, opts, state, continue)
1791+
1792+
{:producer_consumer, state, :hibernate, opts} when is_list(opts) ->
1793+
init_producer_consumer(mod, opts, state, :hibernate)
17351794

17361795
{:consumer, state} ->
1737-
init_consumer(mod, [], state)
1796+
init_consumer(mod, [], state, nil)
1797+
1798+
{:consumer, state, {:continue, _term} = continue} ->
1799+
init_consumer(mod, [], state, continue)
1800+
1801+
{:consumer, state, :hibernate} ->
1802+
init_consumer(mod, [], state, :hibernate)
17381803

17391804
{:consumer, state, opts} when is_list(opts) ->
1740-
init_consumer(mod, opts, state)
1805+
init_consumer(mod, opts, state, nil)
1806+
1807+
{:consumer, state, {:continue, _term} = continue, opts} when is_list(opts) ->
1808+
init_consumer(mod, opts, state, continue)
1809+
1810+
{:consumer, state, :hibernate, opts} when is_list(opts) ->
1811+
init_consumer(mod, opts, state, :hibernate)
17411812

17421813
{:stop, _} = stop ->
17431814
stop
@@ -1750,7 +1821,7 @@ defmodule GenStage do
17501821
end
17511822
end
17521823

1753-
defp init_producer(mod, opts, state) do
1824+
defp init_producer(mod, opts, state, continue_or_hibernate) do
17541825
with {:ok, dispatcher_mod, dispatcher_state, opts} <- init_dispatcher(opts),
17551826
{:ok, buffer_size, opts} <-
17561827
Utils.validate_integer(opts, :buffer_size, 10000, 0, :infinity, true),
@@ -1770,7 +1841,7 @@ defmodule GenStage do
17701841
dispatcher_state: dispatcher_state
17711842
}
17721843

1773-
{:ok, stage}
1844+
if continue_or_hibernate, do: {:ok, stage, continue_or_hibernate}, else: {:ok, stage}
17741845
else
17751846
{:error, message} -> {:stop, {:bad_opts, message}}
17761847
end
@@ -1792,7 +1863,7 @@ defmodule GenStage do
17921863
end
17931864
end
17941865

1795-
defp init_producer_consumer(mod, opts, state) do
1866+
defp init_producer_consumer(mod, opts, state, continue_or_hibernate) do
17961867
with {:ok, dispatcher_mod, dispatcher_state, opts} <- init_dispatcher(opts),
17971868
{:ok, subscribe_to, opts} <- Utils.validate_list(opts, :subscribe_to, []),
17981869
{:ok, buffer_size, opts} <-
@@ -1811,22 +1882,68 @@ defmodule GenStage do
18111882
dispatcher_state: dispatcher_state
18121883
}
18131884

1814-
consumer_init_subscribe(subscribe_to, stage)
1885+
case handle_gen_server_init_args(continue_or_hibernate, stage) do
1886+
{:ok, stage} ->
1887+
consumer_init_subscribe(subscribe_to, stage)
1888+
1889+
{:ok, stage, args} ->
1890+
{:ok, stage} = consumer_init_subscribe(subscribe_to, stage)
1891+
{:ok, stage, args}
1892+
1893+
{:stop, _, _} = error ->
1894+
error
1895+
end
18151896
else
18161897
{:error, message} -> {:stop, {:bad_opts, message}}
18171898
end
18181899
end
18191900

1820-
defp init_consumer(mod, opts, state) do
1901+
defp init_consumer(mod, opts, state, continue_or_hibernate) do
18211902
with {:ok, subscribe_to, opts} <- Utils.validate_list(opts, :subscribe_to, []),
18221903
:ok <- Utils.validate_no_opts(opts) do
18231904
stage = %GenStage{mod: mod, state: state, type: :consumer}
1824-
consumer_init_subscribe(subscribe_to, stage)
1905+
1906+
case handle_gen_server_init_args(continue_or_hibernate, stage) do
1907+
{:ok, stage} ->
1908+
consumer_init_subscribe(subscribe_to, stage)
1909+
1910+
{:ok, stage, args} ->
1911+
{:ok, stage} = consumer_init_subscribe(subscribe_to, stage)
1912+
{:ok, stage, args}
1913+
1914+
{:stop, _, _} = error ->
1915+
error
1916+
end
18251917
else
18261918
{:error, message} -> {:stop, {:bad_opts, message}}
18271919
end
18281920
end
18291921

1922+
defp handle_gen_server_init_args({:continue, _term} = continue, stage) do
1923+
case handle_continue(continue, stage) do
1924+
{:noreply, stage} ->
1925+
{:ok, stage}
1926+
1927+
{:noreply, stage, :hibernate} ->
1928+
{:ok, stage, :hibernate}
1929+
1930+
{:noreply, stage, {:continue, _term} = continue} ->
1931+
{:ok, stage, continue}
1932+
1933+
{:stop, reason, stage} ->
1934+
{:stop, reason, stage}
1935+
end
1936+
end
1937+
1938+
defp handle_gen_server_init_args(:hibernate, stage), do: {:ok, stage, :hibernate}
1939+
defp handle_gen_server_init_args(nil, stage), do: {:ok, stage}
1940+
1941+
@doc false
1942+
1943+
def handle_continue(continue, %{state: state} = stage) do
1944+
noreply_callback(:handle_continue, [continue, state], stage)
1945+
end
1946+
18301947
@doc false
18311948

18321949
def handle_call({:"$info", msg}, _from, stage) do
@@ -1855,6 +1972,10 @@ defmodule GenStage do
18551972
stage = dispatch_events(events, length(events), %{stage | state: state})
18561973
{:reply, reply, stage, :hibernate}
18571974

1975+
{:reply, reply, events, state, {:continue, _term} = continue} ->
1976+
stage = dispatch_events(events, length(events), %{stage | state: state})
1977+
{:reply, reply, stage, continue}
1978+
18581979
{:stop, reason, reply, state} ->
18591980
{:stop, reason, reply, %{stage | state: state}}
18601981

@@ -1995,7 +2116,7 @@ defmodule GenStage do
19952116
case producers do
19962117
%{^ref => entry} ->
19972118
{batches, stage} = consumer_receive(from, entry, events, stage)
1998-
consumer_dispatch(batches, from, mod, state, stage, false)
2119+
consumer_dispatch(batches, from, mod, state, stage, nil)
19992120

20002121
_ ->
20012122
msg = {:"$gen_producer", {self(), ref}, {:cancel, :unknown_subscription}}
@@ -2122,6 +2243,14 @@ defmodule GenStage do
21222243
end
21232244
end
21242245

2246+
defp noreply_callback(:handle_continue, [continue, state], %{mod: mod} = stage) do
2247+
if function_exported?(mod, :handle_continue, 2) do
2248+
handle_noreply_callback(mod.handle_continue(continue, state), stage)
2249+
else
2250+
:error_handler.raise_undef_exception(mod, :handle_continue, [continue, state])
2251+
end
2252+
end
2253+
21252254
defp noreply_callback(callback, args, %{mod: mod} = stage) do
21262255
handle_noreply_callback(apply(mod, callback, args), stage)
21272256
end
@@ -2136,6 +2265,10 @@ defmodule GenStage do
21362265
stage = dispatch_events(events, length(events), %{stage | state: state})
21372266
{:noreply, stage, :hibernate}
21382267

2268+
{:noreply, events, state, {:continue, _term} = continue} when is_list(events) ->
2269+
stage = dispatch_events(events, length(events), %{stage | state: state})
2270+
{:noreply, stage, continue}
2271+
21392272
{:stop, reason, state} ->
21402273
{:stop, reason, %{stage | state: state}}
21412274

@@ -2259,6 +2392,9 @@ defmodule GenStage do
22592392
# main module must know the consumer is no longer subscribed.
22602393
dispatcher_callback(:cancel, [{pid, ref}, dispatcher_state], stage)
22612394

2395+
{:noreply, %{dispatcher_state: dispatcher_state} = stage, _hibernate_or_continue} ->
2396+
dispatcher_callback(:cancel, [{pid, ref}, dispatcher_state], stage)
2397+
22622398
{:stop, _, _} = stop ->
22632399
stop
22642400
end
@@ -2459,17 +2595,22 @@ defmodule GenStage do
24592595
{[{events, 0}], stage}
24602596
end
24612597

2462-
defp consumer_dispatch([{batch, ask} | batches], from, mod, state, stage, _hibernate?) do
2598+
defp consumer_dispatch([{batch, ask} | batches], from, mod, state, stage, _gen_opts) do
24632599
case mod.handle_events(batch, from, state) do
24642600
{:noreply, events, state} when is_list(events) ->
24652601
stage = dispatch_events(events, length(events), stage)
24662602
ask(from, ask, [:noconnect])
2467-
consumer_dispatch(batches, from, mod, state, stage, false)
2603+
consumer_dispatch(batches, from, mod, state, stage, nil)
24682604

2469-
{:noreply, events, state, :hibernate} when is_list(events) ->
2605+
{:noreply, events, state, :hibernate} ->
24702606
stage = dispatch_events(events, length(events), stage)
24712607
ask(from, ask, [:noconnect])
2472-
consumer_dispatch(batches, from, mod, state, stage, true)
2608+
consumer_dispatch(batches, from, mod, state, stage, :hibernate)
2609+
2610+
{:noreply, events, state, {:continue, _} = continue} ->
2611+
stage = dispatch_events(events, length(events), stage)
2612+
ask(from, ask, [:noconnect])
2613+
consumer_dispatch(batches, from, mod, state, stage, continue)
24732614

24742615
{:stop, reason, state} ->
24752616
{:stop, reason, %{stage | state: state}}
@@ -2479,12 +2620,12 @@ defmodule GenStage do
24792620
end
24802621
end
24812622

2482-
defp consumer_dispatch([], _from, _mod, state, stage, false) do
2623+
defp consumer_dispatch([], _from, _mod, state, stage, nil) do
24832624
{:noreply, %{stage | state: state}}
24842625
end
24852626

2486-
defp consumer_dispatch([], _from, _mod, state, stage, true) do
2487-
{:noreply, %{stage | state: state}, :hibernate}
2627+
defp consumer_dispatch([], _from, _mod, state, stage, gen_opts) do
2628+
{:noreply, %{stage | state: state}, gen_opts}
24882629
end
24892630

24902631
defp consumer_subscribe({to, opts}, stage) when is_list(opts),
@@ -2613,11 +2754,11 @@ defmodule GenStage do
26132754
{producer_id, _, _} = entry
26142755
from = {producer_id, ref}
26152756
{batches, stage} = consumer_receive(from, entry, events, stage)
2616-
consumer_dispatch(batches, from, mod, state, stage, false)
2757+
consumer_dispatch(batches, from, mod, state, stage, nil)
26172758

26182759
%{} ->
26192760
# We queued but producer was removed
2620-
consumer_dispatch([{events, 0}], {:pid, ref}, mod, state, stage, false)
2761+
consumer_dispatch([{events, 0}], {:pid, ref}, mod, state, stage, nil)
26212762
end
26222763
end
26232764

@@ -2634,6 +2775,9 @@ defmodule GenStage do
26342775
{:noreply, stage, :hibernate} ->
26352776
take_pc_events(queue, counter, stage)
26362777

2778+
{:noreply, stage, {:continue, _term}} ->
2779+
take_pc_events(queue, counter, stage)
2780+
26372781
{:stop, _, _} = stop ->
26382782
stop
26392783
end
@@ -2646,6 +2790,9 @@ defmodule GenStage do
26462790
{:noreply, %{events: {queue, counter}} = stage, :hibernate} ->
26472791
take_pc_events(queue, counter, stage)
26482792

2793+
{:noreply, %{events: {queue, counter}} = stage, {:continue, _term}} ->
2794+
take_pc_events(queue, counter, stage)
2795+
26492796
{:stop, _, _} = stop ->
26502797
stop
26512798
end

0 commit comments

Comments
 (0)