Skip to content

Commit c48f7a8

Browse files
hazardfnmaartenvanvliet
authored andcommitted
Adds full support for handle_continue/2 to gen_stage
* {:continue, _term} instructions can now be returned as one would expect from gen_server. * :hibernate is now supported on init similar to gen_server.
1 parent 4a5da60 commit c48f7a8

File tree

2 files changed

+372
-22
lines changed

2 files changed

+372
-22
lines changed

lib/gen_stage.ex

Lines changed: 169 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -909,11 +909,18 @@ defmodule GenStage do
909909

910910
@callback init(args :: term) ::
911911
{:producer, state}
912+
| {:producer, state, {:continue, term} | :hibernate}
912913
| {:producer, state, [producer_option]}
914+
| {:producer, state, {:continue, term} | :hibernate, [producer_option]}
913915
| {:producer_consumer, state}
916+
| {:producer_consumer, state, {:continue, term} | :hibernate}
914917
| {:producer_consumer, state, [producer_consumer_option]}
918+
| {:producer_consumer, state, {:continue, term} | :hibernate,
919+
[producer_consumer_option]}
915920
| {:consumer, state}
921+
| {:consumer, state, {:continue, term} | :hibernate}
916922
| {:consumer, state, [consumer_option]}
923+
| {:consumer, state, {:continue, term} | :hibernate, [consumer_option]}
917924
| :ignore
918925
| {:stop, reason :: any}
919926
when state: any
@@ -996,6 +1003,7 @@ defmodule GenStage do
9961003
@callback handle_demand(demand :: pos_integer, state :: term) ::
9971004
{:noreply, [event], new_state}
9981005
| {:noreply, [event], new_state, :hibernate}
1006+
| {:noreply, [event], new_state, {:continue, term}}
9991007
| {:stop, reason, new_state}
10001008
when new_state: term, reason: term, event: term
10011009

@@ -1075,6 +1083,7 @@ defmodule GenStage do
10751083
) ::
10761084
{:noreply, [event], new_state}
10771085
| {:noreply, [event], new_state, :hibernate}
1086+
| {:noreply, [event], new_state, {:continue, term}}
10781087
| {:stop, reason, new_state}
10791088
when event: term, new_state: term, reason: term
10801089

@@ -1088,6 +1097,7 @@ defmodule GenStage do
10881097
@callback handle_events(events :: [event], from, state :: term) ::
10891098
{:noreply, [event], new_state}
10901099
| {:noreply, [event], new_state, :hibernate}
1100+
| {:noreply, [event], new_state, {:continue, term}}
10911101
| {:stop, reason, new_state}
10921102
when new_state: term, reason: term, event: term
10931103

@@ -1129,8 +1139,10 @@ defmodule GenStage do
11291139
@callback handle_call(request :: term, from :: GenServer.from(), state :: term) ::
11301140
{:reply, reply, [event], new_state}
11311141
| {:reply, reply, [event], new_state, :hibernate}
1142+
| {:reply, reply, [event], new_state, {:continue, term}}
11321143
| {:noreply, [event], new_state}
11331144
| {:noreply, [event], new_state, :hibernate}
1145+
| {:noreply, [event], new_state, {:continue, term}}
11341146
| {:stop, reason, reply, new_state}
11351147
| {:stop, reason, new_state}
11361148
when reply: term, new_state: term, reason: term, event: term
@@ -1161,6 +1173,7 @@ defmodule GenStage do
11611173
@callback handle_cast(request :: term, state :: term) ::
11621174
{:noreply, [event], new_state}
11631175
| {:noreply, [event], new_state, :hibernate}
1176+
| {:noreply, [event], new_state, {:continue, term}}
11641177
| {:stop, reason :: term, new_state}
11651178
when new_state: term, event: term
11661179

@@ -1181,6 +1194,27 @@ defmodule GenStage do
11811194
@callback handle_info(message :: term, state :: term) ::
11821195
{:noreply, [event], new_state}
11831196
| {:noreply, [event], new_state, :hibernate}
1197+
| {:noreply, [event], new_state, {:continue, term}}
1198+
| {:stop, reason :: term, new_state}
1199+
when new_state: term, event: term
1200+
1201+
@doc """
1202+
Invoked to handle `continue` instructions.
1203+
1204+
It is useful for performing work after initialization or for splitting the work
1205+
in a callback in multiple steps, updating the process state along the way.
1206+
1207+
Return values are the same as `c:handle_cast/2`.
1208+
1209+
This callback is optional. If one is not implemented, the server will fail
1210+
if a continue instruction is used.
1211+
1212+
This callback is only supported on Erlang/OTP 21+.
1213+
"""
1214+
@callback handle_continue(continue :: term, state :: term) ::
1215+
{:noreply, [event], new_state}
1216+
| {:noreply, [event], new_state, :hibernate}
1217+
| {:noreply, [event], new_state, {:continue, term}}
11841218
| {:stop, reason :: term, new_state}
11851219
when new_state: term, event: term
11861220

@@ -1217,6 +1251,7 @@ defmodule GenStage do
12171251
format_status: 2,
12181252
handle_call: 3,
12191253
handle_cast: 2,
1254+
handle_continue: 2,
12201255
handle_info: 2,
12211256
terminate: 2
12221257
]
@@ -1815,22 +1850,58 @@ defmodule GenStage do
18151850
def init({mod, args}) do
18161851
case mod.init(args) do
18171852
{:producer, state} ->
1818-
init_producer(mod, [], state)
1853+
init_producer(mod, [], state, nil)
1854+
1855+
{:producer, state, {:continue, _term} = continue} ->
1856+
init_producer(mod, [], state, continue)
1857+
1858+
{:producer, state, :hibernate} ->
1859+
init_producer(mod, [], state, :hibernate)
18191860

18201861
{:producer, state, opts} when is_list(opts) ->
1821-
init_producer(mod, opts, state)
1862+
init_producer(mod, opts, state, nil)
1863+
1864+
{:producer, state, {:continue, _term} = continue, opts} when is_list(opts) ->
1865+
init_producer(mod, opts, state, continue)
1866+
1867+
{:producer, state, :hibernate, opts} when is_list(opts) ->
1868+
init_producer(mod, opts, state, :hibernate)
18221869

18231870
{:producer_consumer, state} ->
1824-
init_producer_consumer(mod, [], state)
1871+
init_producer_consumer(mod, [], state, nil)
1872+
1873+
{:producer_consumer, state, {:continue, _term} = continue} ->
1874+
init_producer_consumer(mod, [], state, continue)
1875+
1876+
{:producer_consumer, state, :hibernate} ->
1877+
init_producer_consumer(mod, [], state, :hibernate)
18251878

18261879
{:producer_consumer, state, opts} when is_list(opts) ->
1827-
init_producer_consumer(mod, opts, state)
1880+
init_producer_consumer(mod, opts, state, nil)
1881+
1882+
{:producer_consumer, state, {:continue, _term} = continue, opts} when is_list(opts) ->
1883+
init_producer_consumer(mod, opts, state, continue)
1884+
1885+
{:producer_consumer, state, :hibernate, opts} when is_list(opts) ->
1886+
init_producer_consumer(mod, opts, state, :hibernate)
18281887

18291888
{:consumer, state} ->
1830-
init_consumer(mod, [], state)
1889+
init_consumer(mod, [], state, nil)
1890+
1891+
{:consumer, state, {:continue, _term} = continue} ->
1892+
init_consumer(mod, [], state, continue)
1893+
1894+
{:consumer, state, :hibernate} ->
1895+
init_consumer(mod, [], state, :hibernate)
18311896

18321897
{:consumer, state, opts} when is_list(opts) ->
1833-
init_consumer(mod, opts, state)
1898+
init_consumer(mod, opts, state, nil)
1899+
1900+
{:consumer, state, {:continue, _term} = continue, opts} when is_list(opts) ->
1901+
init_consumer(mod, opts, state, continue)
1902+
1903+
{:consumer, state, :hibernate, opts} when is_list(opts) ->
1904+
init_consumer(mod, opts, state, :hibernate)
18341905

18351906
{:stop, _} = stop ->
18361907
stop
@@ -1843,7 +1914,7 @@ defmodule GenStage do
18431914
end
18441915
end
18451916

1846-
defp init_producer(mod, opts, state) do
1917+
defp init_producer(mod, opts, state, continue_or_hibernate) do
18471918
with {:ok, dispatcher_mod, dispatcher_state, opts} <- init_dispatcher(opts),
18481919
{:ok, buffer_size, opts} <-
18491920
Utils.validate_integer(opts, :buffer_size, 10000, 0, :infinity, true),
@@ -1863,7 +1934,7 @@ defmodule GenStage do
18631934
dispatcher_state: dispatcher_state
18641935
}
18651936

1866-
{:ok, stage}
1937+
if continue_or_hibernate, do: {:ok, stage, continue_or_hibernate}, else: {:ok, stage}
18671938
else
18681939
{:error, message} -> {:stop, {:bad_opts, message}}
18691940
end
@@ -1885,7 +1956,7 @@ defmodule GenStage do
18851956
end
18861957
end
18871958

1888-
defp init_producer_consumer(mod, opts, state) do
1959+
defp init_producer_consumer(mod, opts, state, continue_or_hibernate) do
18891960
with {:ok, dispatcher_mod, dispatcher_state, opts} <- init_dispatcher(opts),
18901961
{:ok, subscribe_to, opts} <- Utils.validate_list(opts, :subscribe_to, []),
18911962
{:ok, buffer_size, opts} <-
@@ -1904,22 +1975,68 @@ defmodule GenStage do
19041975
dispatcher_state: dispatcher_state
19051976
}
19061977

1907-
consumer_init_subscribe(subscribe_to, stage)
1978+
case handle_gen_server_init_args(continue_or_hibernate, stage) do
1979+
{:ok, stage} ->
1980+
consumer_init_subscribe(subscribe_to, stage)
1981+
1982+
{:ok, stage, args} ->
1983+
{:ok, stage} = consumer_init_subscribe(subscribe_to, stage)
1984+
{:ok, stage, args}
1985+
1986+
{:stop, _, _} = error ->
1987+
error
1988+
end
19081989
else
19091990
{:error, message} -> {:stop, {:bad_opts, message}}
19101991
end
19111992
end
19121993

1913-
defp init_consumer(mod, opts, state) do
1994+
defp init_consumer(mod, opts, state, continue_or_hibernate) do
19141995
with {:ok, subscribe_to, opts} <- Utils.validate_list(opts, :subscribe_to, []),
19151996
:ok <- Utils.validate_no_opts(opts) do
19161997
stage = %GenStage{mod: mod, state: state, type: :consumer}
1917-
consumer_init_subscribe(subscribe_to, stage)
1998+
1999+
case handle_gen_server_init_args(continue_or_hibernate, stage) do
2000+
{:ok, stage} ->
2001+
consumer_init_subscribe(subscribe_to, stage)
2002+
2003+
{:ok, stage, args} ->
2004+
{:ok, stage} = consumer_init_subscribe(subscribe_to, stage)
2005+
{:ok, stage, args}
2006+
2007+
{:stop, _, _} = error ->
2008+
error
2009+
end
19182010
else
19192011
{:error, message} -> {:stop, {:bad_opts, message}}
19202012
end
19212013
end
19222014

2015+
defp handle_gen_server_init_args({:continue, _term} = continue, stage) do
2016+
case handle_continue(continue, stage) do
2017+
{:noreply, stage} ->
2018+
{:ok, stage}
2019+
2020+
{:noreply, stage, :hibernate} ->
2021+
{:ok, stage, :hibernate}
2022+
2023+
{:noreply, stage, {:continue, _term} = continue} ->
2024+
{:ok, stage, continue}
2025+
2026+
{:stop, reason, stage} ->
2027+
{:stop, reason, stage}
2028+
end
2029+
end
2030+
2031+
defp handle_gen_server_init_args(:hibernate, stage), do: {:ok, stage, :hibernate}
2032+
defp handle_gen_server_init_args(nil, stage), do: {:ok, stage}
2033+
2034+
@doc false
2035+
2036+
def handle_continue(continue, %{state: state} = stage) do
2037+
noreply_callback(:handle_continue, [continue, state], stage)
2038+
end
2039+
19232040
@doc false
19242041

19252042
def handle_call({:"$info", msg}, _from, stage) do
@@ -1948,6 +2065,10 @@ defmodule GenStage do
19482065
stage = dispatch_events(events, length(events), %{stage | state: state})
19492066
{:reply, reply, stage, :hibernate}
19502067

2068+
{:reply, reply, events, state, {:continue, _term} = continue} ->
2069+
stage = dispatch_events(events, length(events), %{stage | state: state})
2070+
{:reply, reply, stage, continue}
2071+
19512072
{:stop, reason, reply, state} ->
19522073
{:stop, reason, reply, %{stage | state: state}}
19532074

@@ -2092,7 +2213,7 @@ defmodule GenStage do
20922213
case producers do
20932214
%{^ref => entry} ->
20942215
{batches, stage} = consumer_receive(from, entry, events, stage)
2095-
consumer_dispatch(batches, from, mod, state, stage, false)
2216+
consumer_dispatch(batches, from, mod, state, stage, nil)
20962217

20972218
_ ->
20982219
msg = {:"$gen_producer", {self(), ref}, {:cancel, :unknown_subscription}}
@@ -2219,6 +2340,14 @@ defmodule GenStage do
22192340
end
22202341
end
22212342

2343+
defp noreply_callback(:handle_continue, [continue, state], %{mod: mod} = stage) do
2344+
if function_exported?(mod, :handle_continue, 2) do
2345+
handle_noreply_callback(mod.handle_continue(continue, state), stage)
2346+
else
2347+
:error_handler.raise_undef_exception(mod, :handle_continue, [continue, state])
2348+
end
2349+
end
2350+
22222351
defp noreply_callback(callback, args, %{mod: mod} = stage) do
22232352
handle_noreply_callback(apply(mod, callback, args), stage)
22242353
end
@@ -2233,6 +2362,10 @@ defmodule GenStage do
22332362
stage = dispatch_events(events, length(events), %{stage | state: state})
22342363
{:noreply, stage, :hibernate}
22352364

2365+
{:noreply, events, state, {:continue, _term} = continue} when is_list(events) ->
2366+
stage = dispatch_events(events, length(events), %{stage | state: state})
2367+
{:noreply, stage, continue}
2368+
22362369
{:stop, reason, state} ->
22372370
{:stop, reason, %{stage | state: state}}
22382371

@@ -2364,6 +2497,9 @@ defmodule GenStage do
23642497
# main module must know the consumer is no longer subscribed.
23652498
dispatcher_callback(:cancel, [{pid, ref}, dispatcher_state], stage)
23662499

2500+
{:noreply, %{dispatcher_state: dispatcher_state} = stage, _hibernate_or_continue} ->
2501+
dispatcher_callback(:cancel, [{pid, ref}, dispatcher_state], stage)
2502+
23672503
{:stop, _, _} = stop ->
23682504
stop
23692505
end
@@ -2574,17 +2710,22 @@ defmodule GenStage do
25742710
{[{events, 0}], stage}
25752711
end
25762712

2577-
defp consumer_dispatch([{batch, ask} | batches], from, mod, state, stage, _hibernate?) do
2713+
defp consumer_dispatch([{batch, ask} | batches], from, mod, state, stage, _gen_opts) do
25782714
case mod.handle_events(batch, from, state) do
25792715
{:noreply, events, state} when is_list(events) ->
25802716
stage = dispatch_events(events, length(events), stage)
25812717
ask(from, ask, [:noconnect])
2582-
consumer_dispatch(batches, from, mod, state, stage, false)
2718+
consumer_dispatch(batches, from, mod, state, stage, nil)
25832719

2584-
{:noreply, events, state, :hibernate} when is_list(events) ->
2720+
{:noreply, events, state, :hibernate} ->
25852721
stage = dispatch_events(events, length(events), stage)
25862722
ask(from, ask, [:noconnect])
2587-
consumer_dispatch(batches, from, mod, state, stage, true)
2723+
consumer_dispatch(batches, from, mod, state, stage, :hibernate)
2724+
2725+
{:noreply, events, state, {:continue, _} = continue} ->
2726+
stage = dispatch_events(events, length(events), stage)
2727+
ask(from, ask, [:noconnect])
2728+
consumer_dispatch(batches, from, mod, state, stage, continue)
25882729

25892730
{:stop, reason, state} ->
25902731
{:stop, reason, %{stage | state: state}}
@@ -2594,12 +2735,12 @@ defmodule GenStage do
25942735
end
25952736
end
25962737

2597-
defp consumer_dispatch([], _from, _mod, state, stage, false) do
2738+
defp consumer_dispatch([], _from, _mod, state, stage, nil) do
25982739
{:noreply, %{stage | state: state}}
25992740
end
26002741

2601-
defp consumer_dispatch([], _from, _mod, state, stage, true) do
2602-
{:noreply, %{stage | state: state}, :hibernate}
2742+
defp consumer_dispatch([], _from, _mod, state, stage, gen_opts) do
2743+
{:noreply, %{stage | state: state}, gen_opts}
26032744
end
26042745

26052746
defp consumer_subscribe({to, opts}, stage) when is_list(opts),
@@ -2738,11 +2879,11 @@ defmodule GenStage do
27382879
{producer_id, _, _} = entry
27392880
from = {producer_id, ref}
27402881
{batches, stage} = consumer_receive(from, entry, events, stage)
2741-
consumer_dispatch(batches, from, mod, state, stage, false)
2882+
consumer_dispatch(batches, from, mod, state, stage, nil)
27422883

27432884
%{} ->
27442885
# We queued but producer was removed
2745-
consumer_dispatch([{events, 0}], {:pid, ref}, mod, state, stage, false)
2886+
consumer_dispatch([{events, 0}], {:pid, ref}, mod, state, stage, nil)
27462887
end
27472888
end
27482889

@@ -2759,6 +2900,9 @@ defmodule GenStage do
27592900
{:noreply, stage, :hibernate} ->
27602901
take_pc_events(queue, counter, stage)
27612902

2903+
{:noreply, stage, {:continue, _term}} ->
2904+
take_pc_events(queue, counter, stage)
2905+
27622906
{:stop, _, _} = stop ->
27632907
stop
27642908
end
@@ -2771,6 +2915,9 @@ defmodule GenStage do
27712915
{:noreply, %{events: {queue, counter}} = stage, :hibernate} ->
27722916
take_pc_events(queue, counter, stage)
27732917

2918+
{:noreply, %{events: {queue, counter}} = stage, {:continue, _term}} ->
2919+
take_pc_events(queue, counter, stage)
2920+
27742921
{:stop, _, _} = stop ->
27752922
stop
27762923
end

0 commit comments

Comments
 (0)