1717
1818import static org .junit .Assert .*;
1919
20- import java .util .ArrayList ;
2120import java .util .Arrays ;
22- import java .util .HashMap ;
23- import java .util .List ;
21+ import java .util .Collection ;
2422import java .util .Map ;
2523import java .util .concurrent .ConcurrentHashMap ;
24+ import java .util .concurrent .ConcurrentLinkedQueue ;
25+ import java .util .concurrent .CountDownLatch ;
26+ import java .util .concurrent .TimeUnit ;
27+ import java .util .concurrent .atomic .AtomicInteger ;
28+ import java .util .concurrent .atomic .AtomicReference ;
2629
2730import org .junit .Test ;
2831
2932import rx .Observable ;
3033import rx .Observer ;
3134import rx .Subscription ;
3235import rx .observables .GroupedObservable ;
36+ import rx .subscriptions .Subscriptions ;
37+ import rx .util .functions .Action1 ;
3338import rx .util .functions .Func1 ;
3439import rx .util .functions .Functions ;
3540
@@ -55,69 +60,137 @@ public static <K, T> Func1<Observer<GroupedObservable<K, T>>, Subscription> grou
5560 }
5661
5762 private static class GroupBy <K , V > implements Func1 <Observer <GroupedObservable <K , V >>, Subscription > {
63+
5864 private final Observable <KeyValue <K , V >> source ;
65+ private final ConcurrentHashMap <K , GroupedSubject <K , V >> groupedObservables = new ConcurrentHashMap <K , GroupedSubject <K , V >>();
5966
6067 private GroupBy (Observable <KeyValue <K , V >> source ) {
6168 this .source = source ;
6269 }
6370
6471 @ Override
6572 public Subscription call (final Observer <GroupedObservable <K , V >> observer ) {
66- return source .subscribe (new GroupByObserver (observer ));
73+ return source .subscribe (new Observer <KeyValue <K , V >>() {
74+
75+ @ Override
76+ public void onCompleted () {
77+ // we need to propagate to all children I imagine ... we can't just leave all of those Observable/Observers hanging
78+ for (GroupedSubject <K , V > o : groupedObservables .values ()) {
79+ o .onCompleted ();
80+ }
81+ // now the parent
82+ observer .onCompleted ();
83+ }
84+
85+ @ Override
86+ public void onError (Exception e ) {
87+ // we need to propagate to all children I imagine ... we can't just leave all of those Observable/Observers hanging
88+ for (GroupedSubject <K , V > o : groupedObservables .values ()) {
89+ o .onError (e );
90+ }
91+ // now the parent
92+ observer .onError (e );
93+ }
94+
95+ @ Override
96+ public void onNext (KeyValue <K , V > value ) {
97+ GroupedSubject <K , V > gs = groupedObservables .get (value .key );
98+ if (gs == null ) {
99+ /*
100+ * Technically the source should be single-threaded so we shouldn't need to do this but I am
101+ * programming defensively as most operators are so this can work with a concurrent sequence
102+ * if it ends up receiving one.
103+ */
104+ GroupedSubject <K , V > newGs = GroupedSubject .<K , V > create (value .key );
105+ GroupedSubject <K , V > existing = groupedObservables .putIfAbsent (value .key , newGs );
106+ if (existing == null ) {
107+ // we won so use the one we created
108+ gs = newGs ;
109+ // since we won the creation we emit this new GroupedObservable
110+ observer .onNext (gs );
111+ } else {
112+ // another thread beat us so use the existing one
113+ gs = existing ;
114+ }
115+ }
116+ gs .onNext (value .value );
117+ }
118+ });
67119 }
120+ }
68121
69- private class GroupByObserver implements Observer <KeyValue <K , V >> {
70- private final Observer <GroupedObservable <K , V >> underlying ;
122+ private static class GroupedSubject <K , T > extends GroupedObservable <K , T > implements Observer <T > {
71123
72- private final ConcurrentHashMap <K , Boolean > keys = new ConcurrentHashMap <K , Boolean >();
124+ static <K , T > GroupedSubject <K , T > create (K key ) {
125+ @ SuppressWarnings ("unchecked" )
126+ final AtomicReference <Observer <T >> subscribedObserver = new AtomicReference <Observer <T >>(EMPTY_OBSERVER );
73127
74- private GroupByObserver (Observer <GroupedObservable <K , V >> underlying ) {
75- this .underlying = underlying ;
76- }
128+ return new GroupedSubject <K , T >(key , new Func1 <Observer <T >, Subscription >() {
77129
78- @ Override
79- public void onCompleted ( ) {
80- underlying . onCompleted ();
81- }
130+ @ Override
131+ public Subscription call ( Observer < T > observer ) {
132+ // register Observer
133+ subscribedObserver . set ( observer );
82134
83- @ Override
84- public void onError (Exception e ) {
85- underlying .onError (e );
86- }
135+ return new Subscription () {
87136
88- @ Override
89- public void onNext (final KeyValue <K , V > args ) {
90- K key = args .key ;
91- boolean newGroup = keys .putIfAbsent (key , true ) == null ;
92- if (newGroup ) {
93- underlying .onNext (buildObservableFor (source , key ));
137+ @ SuppressWarnings ("unchecked" )
138+ @ Override
139+ public void unsubscribe () {
140+ // we remove the Observer so we stop emitting further events (they will be ignored if parent continues to send)
141+ subscribedObserver .set (EMPTY_OBSERVER );
142+ // I don't believe we need to worry about the parent here as it's a separate sequence that would
143+ // be unsubscribed to directly if that needs to happen.
144+ }
145+ };
94146 }
95- }
147+ }, subscribedObserver );
96148 }
97- }
98149
99- private static <K , R > GroupedObservable <K , R > buildObservableFor (Observable <KeyValue <K , R >> source , final K key ) {
100- final Observable <R > observable = source .filter (new Func1 <KeyValue <K , R >, Boolean >() {
101- @ Override
102- public Boolean call (KeyValue <K , R > pair ) {
103- return key .equals (pair .key );
104- }
105- }).map (new Func1 <KeyValue <K , R >, R >() {
106- @ Override
107- public R call (KeyValue <K , R > pair ) {
108- return pair .value ;
109- }
110- });
111- return new GroupedObservable <K , R >(key , new Func1 <Observer <R >, Subscription >() {
150+ private final AtomicReference <Observer <T >> subscribedObserver ;
112151
113- @ Override
114- public Subscription call (Observer <R > observer ) {
115- return observable .subscribe (observer );
116- }
152+ public GroupedSubject (K key , Func1 <Observer <T >, Subscription > onSubscribe , AtomicReference <Observer <T >> subscribedObserver ) {
153+ super (key , onSubscribe );
154+ this .subscribedObserver = subscribedObserver ;
155+ }
156+
157+ @ Override
158+ public void onCompleted () {
159+ subscribedObserver .get ().onCompleted ();
160+ }
161+
162+ @ Override
163+ public void onError (Exception e ) {
164+ subscribedObserver .get ().onError (e );
165+ }
166+
167+ @ Override
168+ public void onNext (T v ) {
169+ subscribedObserver .get ().onNext (v );
170+ }
117171
118- });
119172 }
120173
174+ @ SuppressWarnings ("rawtypes" )
175+ private static Observer EMPTY_OBSERVER = new Observer () {
176+
177+ @ Override
178+ public void onCompleted () {
179+ // do nothing
180+ }
181+
182+ @ Override
183+ public void onError (Exception e ) {
184+ // do nothing
185+ }
186+
187+ @ Override
188+ public void onNext (Object args ) {
189+ // do nothing
190+ }
191+
192+ };
193+
121194 private static class KeyValue <K , V > {
122195 private final K key ;
123196 private final V value ;
@@ -141,45 +214,146 @@ public void testGroupBy() {
141214 Observable <String > source = Observable .from ("one" , "two" , "three" , "four" , "five" , "six" );
142215 Observable <GroupedObservable <Integer , String >> grouped = Observable .create (groupBy (source , length ));
143216
144- Map <Integer , List <String >> map = toMap (grouped );
217+ Map <Integer , Collection <String >> map = toMap (grouped );
145218
146219 assertEquals (3 , map .size ());
147- assertEquals (Arrays .asList ("one" , "two" , "six" ), map .get (3 ));
148- assertEquals (Arrays .asList ("four" , "five" ), map .get (4 ));
149- assertEquals (Arrays .asList ("three" ), map .get (5 ));
150-
220+ assertArrayEquals (Arrays .asList ("one" , "two" , "six" ).toArray (), map .get (3 ).toArray ());
221+ assertArrayEquals (Arrays .asList ("four" , "five" ).toArray (), map .get (4 ).toArray ());
222+ assertArrayEquals (Arrays .asList ("three" ).toArray (), map .get (5 ).toArray ());
151223 }
152224
153225 @ Test
154226 public void testEmpty () {
155227 Observable <String > source = Observable .from ();
156228 Observable <GroupedObservable <Integer , String >> grouped = Observable .create (groupBy (source , length ));
157229
158- Map <Integer , List <String >> map = toMap (grouped );
230+ Map <Integer , Collection <String >> map = toMap (grouped );
159231
160232 assertTrue (map .isEmpty ());
161233 }
162234
163- private static <K , V > Map <K , List <V >> toMap (Observable <GroupedObservable <K , V >> observable ) {
164- Map <K , List <V >> result = new HashMap <K , List <V >>();
165- for (GroupedObservable <K , V > g : observable .toBlockingObservable ().toIterable ()) {
166- K key = g .getKey ();
235+ private static <K , V > Map <K , Collection <V >> toMap (Observable <GroupedObservable <K , V >> observable ) {
167236
168- for (V value : g .toBlockingObservable ().toIterable ()) {
169- List <V > values = result .get (key );
170- if (values == null ) {
171- values = new ArrayList <V >();
172- result .put (key , values );
173- }
237+ final ConcurrentHashMap <K , Collection <V >> result = new ConcurrentHashMap <K , Collection <V >>();
174238
175- values .add (value );
176- }
239+ observable .forEach (new Action1 <GroupedObservable <K , V >>() {
177240
178- }
241+ @ Override
242+ public void call (final GroupedObservable <K , V > o ) {
243+ result .put (o .getKey (), new ConcurrentLinkedQueue <V >());
244+ o .subscribe (new Action1 <V >() {
245+
246+ @ Override
247+ public void call (V v ) {
248+ result .get (o .getKey ()).add (v );
249+ }
250+
251+ });
252+ }
253+ });
179254
180255 return result ;
181256 }
182257
258+ /**
259+ * Assert that only a single subscription to a stream occurs and that all events are received.
260+ *
261+ * @throws Exception
262+ */
263+ @ Test
264+ public void testGroupedEventStream () throws Exception {
265+
266+ final AtomicInteger eventCounter = new AtomicInteger ();
267+ final AtomicInteger subscribeCounter = new AtomicInteger ();
268+ final AtomicInteger groupCounter = new AtomicInteger ();
269+ final CountDownLatch latch = new CountDownLatch (1 );
270+ final int count = 100 ;
271+ final int groupCount = 2 ;
272+
273+ Observable <Event > es = Observable .create (new Func1 <Observer <Event >, Subscription >() {
274+
275+ @ Override
276+ public Subscription call (final Observer <Event > observer ) {
277+ System .out .println ("*** Subscribing to EventStream ***" );
278+ subscribeCounter .incrementAndGet ();
279+ new Thread (new Runnable () {
280+
281+ @ Override
282+ public void run () {
283+ for (int i = 0 ; i < count ; i ++) {
284+ Event e = new Event ();
285+ e .source = i % groupCount ;
286+ e .message = "Event-" + i ;
287+ observer .onNext (e );
288+ }
289+ observer .onCompleted ();
290+ }
291+
292+ }).start ();
293+ return Subscriptions .empty ();
294+ }
295+
296+ });
297+
298+ es .groupBy (new Func1 <Event , Integer >() {
299+
300+ @ Override
301+ public Integer call (Event e ) {
302+ return e .source ;
303+ }
304+ }).mapMany (new Func1 <GroupedObservable <Integer , Event >, Observable <String >>() {
305+
306+ @ Override
307+ public Observable <String > call (GroupedObservable <Integer , Event > eventGroupedObservable ) {
308+ System .out .println ("GroupedObservable Key: " + eventGroupedObservable .getKey ());
309+ groupCounter .incrementAndGet ();
310+
311+ return eventGroupedObservable .map (new Func1 <Event , String >() {
312+
313+ @ Override
314+ public String call (Event event ) {
315+ return "Source: " + event .source + " Message: " + event .message ;
316+ }
317+ });
318+
319+ };
320+ }).subscribe (new Observer <String >() {
321+
322+ @ Override
323+ public void onCompleted () {
324+ latch .countDown ();
325+ }
326+
327+ @ Override
328+ public void onError (Exception e ) {
329+ e .printStackTrace ();
330+ latch .countDown ();
331+ }
332+
333+ @ Override
334+ public void onNext (String outputMessage ) {
335+ System .out .println (outputMessage );
336+ eventCounter .incrementAndGet ();
337+ }
338+ });
339+
340+ latch .await (5000 , TimeUnit .MILLISECONDS );
341+ assertEquals (1 , subscribeCounter .get ());
342+ assertEquals (groupCount , groupCounter .get ());
343+ assertEquals (count , eventCounter .get ());
344+
345+ }
346+
347+ private static class Event {
348+ int source ;
349+ String message ;
350+
351+ @ Override
352+ public String toString () {
353+ return "Event => source: " + source + " message: " + message ;
354+ }
355+ }
356+
183357 }
184358
185359}
0 commit comments