Skip to content
This repository was archived by the owner on Mar 5, 2024. It is now read-only.

Commit 8694238

Browse files
committed
Implement random elements selection from a stream
via the optimal reservoir sampling algorithm.
1 parent 87906a3 commit 8694238

File tree

1 file changed

+130
-1
lines changed

1 file changed

+130
-1
lines changed

src/data/data_stream.erl

Lines changed: 130 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,16 @@
1616
lazy_map/2,
1717
lazy_filter/2,
1818
pmap_to_bag/2,
19-
pmap_to_bag/3
19+
pmap_to_bag/3,
20+
random_elements/2
2021
]).
2122

2223
-define(T, ?MODULE).
2324

25+
%% Optional value: `none' when absent, `{some, A}' when present.
-type opt(A) :: none | {some, A}.
26+
27+
%% Sample buffer used by reservoir sampling: maps a slot number to the element
%% currently held in that slot.
-type reservoir(A) :: #{pos_integer() => A}.
28+
2429
-type filter(A, B)
2530
:: {map, fun((A) -> B)}
2631
| {test, fun((A) -> boolean())}
@@ -190,8 +195,87 @@ pmap_to_bag(T, F, J) when is_function(F), is_integer(J), J > 0 ->
190195
error({data_stream_scheduler_crashed_before_sending_results, Reason})
191196
end.
192197

198+
%% @doc
%% Pick up to K elements at random from the stream T in a single pass.
%% Returns fewer than K elements when the stream holds fewer than K. The order
%% of the returned list is unspecified.
%% @end
-spec random_elements(t(A), non_neg_integer()) -> [A].
random_elements(_, 0) ->
    [];
% is_integer/1 is required in addition to K > 0: in Erlang's term order every
% non-number compares greater than any number, so an atom or a list would
% otherwise pass the guard and the whole stream would end up in the result.
random_elements(T, K) when is_integer(K), K > 0 ->
    {_N, Reservoir} = reservoir_sample(T, #{}, K),
    maps:values(Reservoir).
203+
193204
%% Internal ===================================================================
194205

206+
%% @doc
%% The optimal reservoir sampling algorithm. Known as "Algorithm L" in:
%% https://dl.acm.org/doi/pdf/10.1145/198429.198435
%% https://en.wikipedia.org/wiki/Reservoir_sampling#An_optimal_algorithm
%%
%% Fills the reservoir R0 with the first K elements of T0, then keeps
%% replacing randomly chosen slots with later elements of the stream.
%% Returns {N, Reservoir} where N is the number of elements read from the
%% stream; N is 0 when the stream is empty.
%% @end
-spec reservoir_sample(t(A), reservoir(A), pos_integer()) ->
    {non_neg_integer(), reservoir(A)}.
reservoir_sample(T0, R0, K) ->
    case reservoir_sample_init(T0, R0, 1, K) of
        {none, R1, I} ->
            % The stream ended before K elements could be read, so the
            % reservoir already holds everything there is.
            {I, R1};
        {{some, T1}, R1, I} ->
            W = random_weight_init(K),
            J = random_index_next(I, W),
            reservoir_sample_update(T1, R1, W, I, J, K)
    end.
222+
223+
%% Fill reservoir slots I..K with the next elements of the stream.
%% Returns {Rest, Reservoir, Count} where Count is the number of elements
%% placed so far (I - 1, hence 0 for an empty stream) and Rest is
%% `{some, Stream}' once slot K has been filled, or `none' when the stream
%% ended first.
-spec reservoir_sample_init(t(A), reservoir(A), pos_integer(), pos_integer()) ->
    {opt(t(A)), reservoir(A), non_neg_integer()}.
reservoir_sample_init(T, R, I, K) when I > K ->
    {{some, T}, R, I - 1};
reservoir_sample_init(T0, R, I, K) ->
    case next(T0) of
        {some, {X, T1}} ->
            reservoir_sample_init(T1, R#{I => X}, I + 1, K);
        none ->
            {none, R, I - 1}
    end.
237+
238+
%% Draw a random weight U^(1/K) for a uniform random U, computed as
%% exp(ln(U) / K).
%% NOTE(review): rand:uniform/0 is documented as 0.0 =< X < 1.0, so an exact
%% 0.0 draw would make math:log/1 raise -- confirm that is acceptable.
-spec random_weight_init(pos_integer()) -> float().
random_weight_init(K) ->
    U = rand:uniform(),
    LogU = math:log(U),
    math:exp(LogU / K).
241+
242+
%% Scale the current weight W by a freshly drawn random weight for
%% reservoir size K.
-spec random_weight_next(float(), pos_integer()) -> float().
random_weight_next(W, K) ->
    Factor = random_weight_init(K),
    W * Factor.
245+
246+
%% Compute the index of the next element to sample: I plus a random skip of
%% floor(ln(U) / ln(1 - W)) plus one, for a uniform random U.
-spec random_index_next(pos_integer(), float()) -> pos_integer().
random_index_next(I, W) ->
    U = rand:uniform(),
    Skip = floor(math:log(U) / math:log(1 - W)),
    I + Skip + 1.
249+
250+
%% Walk the rest of the stream once the reservoir has been filled.
%%   T0 - remaining stream
%%   R0 - reservoir built so far
%%   W0 - current weight
%%   I0 - number of elements read from the stream so far
%%   J0 - 1-based position of the next element to place in the reservoir
%%   K  - reservoir size
%% Returns {N, Reservoir} where N is the total number of elements read.
-spec reservoir_sample_update(
    t(A),
    reservoir(A),
    float(),
    pos_integer(),
    pos_integer(),
    pos_integer()
) ->
    {pos_integer(), reservoir(A)}.
reservoir_sample_update(T0, R0, W0, I0, J0, K) ->
    case next(T0) of
        none ->
            {I0, R0};
        {some, {X, T1}} ->
            % X is element number I0 + 1 of the stream, so that is the
            % position to compare against J0. Comparing I0 instead would make
            % element K + 1 unselectable and shift every pick by one.
            I1 = I0 + 1,
            case I1 =:= J0 of
                true ->
                    R1 = R0#{rand:uniform(K) => X},
                    W1 = random_weight_next(W0, K),
                    % Algorithm L derives the next skip from the refreshed
                    % weight, not from the one used for the current pick.
                    J1 = random_index_next(J0, W1),
                    reservoir_sample_update(T1, R1, W1, I1, J1, K);
                false ->
                    % Here is where the big win takes place over the simple
                    % Algorithm R. We skip computing random numbers for an
                    % element that will not be picked.
                    reservoir_sample_update(T1, R0, W0, I1, J0, K)
            end
    end.
278+
195279
-spec sched(#sched{}) -> [any()].
196280
sched(#sched{id=_, producers=[], consumers=[], consumers_free=[], work=[], results=Ys}) ->
197281
Ys;
@@ -396,4 +480,49 @@ fold_test_() ->
396480
]
397481
].
398482

483+
%% Checks result sizes on small fixed inputs, then, for every N in 10..100 and
%% each sampling fraction, that at least one of ten samples differs from the
%% first K elements of the input.
random_elements_test_() ->
    SizeCases =
        [
            ?_assertMatch([a], random_elements(from_list([a]), 1)),
            ?_assertEqual(0, length(random_elements(from_list([]), 1))),
            ?_assertEqual(0, length(random_elements(from_list([]), 10))),
            ?_assertEqual(0, length(random_elements(from_list([]), 100))),
            ?_assertEqual(1, length(random_elements(from_list(lists:seq(1, 100)), 1))),
            ?_assertEqual(2, length(random_elements(from_list(lists:seq(1, 100)), 2))),
            ?_assertEqual(3, length(random_elements(from_list(lists:seq(1, 100)), 3))),
            ?_assertEqual(5, length(random_elements(from_list(lists:seq(1, 100)), 5)))
        ],
    % Builds one named test for a given input size N and sampling fraction KF.
    NewSequenceCase =
        fun (N, KF) ->
            Trials = 10,
            K = floor(N * KF),
            L = lists:seq(1, N),
            S = from_list(L),
            Rands = [random_elements(S, K) || _ <- lists:seq(1, Trials)],
            Head = lists:sublist(L, K),
            Unique = lists:usort(Rands) -- [Head],
            Name =
                lists:flatten(io_lib:format(
                    "At least 1/~p of trials makes a new sequence. "
                    "N:~p K:~p KF:~p length(Unique):~p",
                    [Trials, N, K, KF, length(Unique)]
                )),
            {Name, ?_assertMatch([_|_], Unique)}
        end,
    RandomCases =
        [
            NewSequenceCase(N, KF)
        ||
            N <- lists:seq(10, 100),
            KF <- [0.25, 0.50, 0.75]
        ],
    {inparallel, SizeCases ++ RandomCases}.
527+
399528
-endif.

0 commit comments

Comments
 (0)