Skip to content

Commit dc3b3f1

Browse files
Merge pull request #1304 from zsxwing/flatMap
Add flatMap and concatMap to RxScala
2 parents 400f69b + 38f96db commit dc3b3f1

File tree

3 files changed

+181
-0
lines changed

3 files changed

+181
-0
lines changed

language-adaptors/rxjava-scala/src/examples/scala/rx/lang/scala/examples/RxScalaDemo.scala

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,4 +1108,72 @@ class RxScalaDemo extends JUnitSuite {
11081108
}
11091109
o.toBlockingObservable.foreach(println(_))
11101110
}
1111+
1112+
@Test def flatMapExample() {
1113+
val o = Observable.items(10, 100)
1114+
o.flatMap(n => Observable.interval(200 millis).map(_ * n))
1115+
.take(20)
1116+
.toBlocking.foreach(println)
1117+
}
1118+
1119+
@Test def flatMapExample2() {
1120+
val o = Observable.items(10, 100)
1121+
val o1 = for (n <- o;
1122+
i <- Observable.interval(200 millis)) yield i * n
1123+
o1.take(20).toBlocking.foreach(println)
1124+
}
1125+
1126+
@Test def flatMapExample3() {
1127+
val o = Observable[Int] {
1128+
subscriber =>
1129+
subscriber.onNext(10)
1130+
subscriber.onNext(100)
1131+
subscriber.onError(new IOException("Oops"))
1132+
}
1133+
o.flatMap(
1134+
(n: Int) => Observable.interval(200 millis).map(_ * n),
1135+
e => Observable.interval(200 millis).map(_ * -1),
1136+
() => Observable.interval(200 millis).map(_ * 1000)
1137+
).take(20)
1138+
.toBlocking.foreach(println)
1139+
}
1140+
1141+
@Test def flatMapExample4() {
1142+
val o = Observable.items(10, 100)
1143+
o.flatMap(
1144+
(n: Int) => Observable.interval(200 millis).map(_ * n),
1145+
e => Observable.interval(200 millis).map(_ * -1),
1146+
() => Observable.interval(200 millis).map(_ * 1000)
1147+
).take(20)
1148+
.toBlocking.foreach(println)
1149+
}
1150+
1151+
@Test def flatMapExample5() {
1152+
val o = Observable.items(1, 10, 100, 1000)
1153+
o.flatMap(
1154+
(n: Int) => Observable.interval(200 millis).take(5),
1155+
(n: Int, m: Long) => n * m
1156+
).toBlocking.foreach(println)
1157+
}
1158+
1159+
@Test def flatMapIterableExample() {
1160+
val o = Observable.items(10, 100)
1161+
o.flatMapIterable(n => (1 to 20).map(_ * n))
1162+
.toBlocking.foreach(println)
1163+
}
1164+
1165+
@Test def flatMapIterableExample2() {
1166+
val o = Observable.items(1, 10, 100, 1000)
1167+
o.flatMapIterable(
1168+
(n: Int) => (1 to 5),
1169+
(n: Int, m: Int) => n * m
1170+
).toBlocking.foreach(println)
1171+
}
1172+
1173+
@Test def concatMapExample() {
1174+
val o = Observable.items(10, 100)
1175+
o.concatMap(n => Observable.interval(200 millis).map(_ * n).take(10))
1176+
.take(20)
1177+
.toBlocking.foreach(println)
1178+
}
11111179
}

language-adaptors/rxjava-scala/src/main/scala/rx/lang/scala/Observable.scala

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,25 @@ trait Observable[+T]
298298
toScalaObservable[U](o5)
299299
}
300300

301+
/**
302+
* Returns a new Observable that emits items resulting from applying a function that you supply to each item
303+
* emitted by the source Observable, where that function returns an Observable, and then emitting the items
304+
* that result from concatinating those resulting Observables.
305+
*
306+
* <img width="640" height="305" src="https://raw.github.com/wiki/Netflix/RxJava/images/rx-operators/concatMap.png">
307+
*
308+
* @param f a function that, when applied to an item emitted by the source Observable, returns an Observable
309+
* @return an Observable that emits the result of applying the transformation function to each item emitted
310+
* by the source Observable and concatinating the Observables obtained from this transformation
311+
*/
312+
def concatMap[R](f: T => Observable[R]): Observable[R] = {
313+
toScalaObservable[R](asJavaObservable.concatMap[R](new Func1[T, rx.Observable[_ <: R]] {
314+
def call(t1: T): rx.Observable[_ <: R] = {
315+
f(t1).asJavaObservable
316+
}
317+
}))
318+
}
319+
301320
/**
302321
* Wraps this Observable in another Observable that ensures that the resulting
303322
* Observable is chronologically well-behaved.
@@ -883,6 +902,95 @@ trait Observable[+T]
883902
}))
884903
}
885904

905+
/**
906+
* Returns an Observable that applies a function to each item emitted or notification raised by the source
907+
* Observable and then flattens the Observables returned from these functions and emits the resulting items.
908+
*
909+
* <img width="640" height="410" src="https://raw.github.com/wiki/Netflix/RxJava/images/rx-operators/mergeMap.nce.png">
910+
*
911+
* @tparam R the result type
912+
* @param onNext a function that returns an Observable to merge for each item emitted by the source Observable
913+
* @param onError a function that returns an Observable to merge for an onError notification from the source
914+
* Observable
915+
* @param onCompleted a function that returns an Observable to merge for an onCompleted notification from the source
916+
* Observable
917+
* @return an Observable that emits the results of merging the Observables returned from applying the
918+
* specified functions to the emissions and notifications of the source Observable
919+
*/
920+
def flatMap[R](onNext: T => Observable[R], onError: Throwable => Observable[R], onCompleted: () => Observable[R]): Observable[R] = {
921+
val jOnNext = new Func1[T, rx.Observable[_ <: R]] {
922+
override def call(t: T): rx.Observable[_ <: R] = onNext(t).asJavaObservable
923+
}
924+
val jOnError = new Func1[Throwable, rx.Observable[_ <: R]] {
925+
override def call(e: Throwable): rx.Observable[_ <: R] = onError(e).asJavaObservable
926+
}
927+
val jOnCompleted = new Func0[rx.Observable[_ <: R]] {
928+
override def call(): rx.Observable[_ <: R] = onCompleted().asJavaObservable
929+
}
930+
toScalaObservable[R](asJavaObservable.mergeMap[R](jOnNext, jOnError, jOnCompleted))
931+
}
932+
933+
/**
934+
* Returns an Observable that emits the results of a specified function to the pair of values emitted by the
935+
* source Observable and a specified collection Observable.
936+
*
937+
* <img width="640" height="390" src="https://raw.github.com/wiki/Netflix/RxJava/images/rx-operators/mergeMap.r.png">
938+
*
939+
* @tparam U the type of items emitted by the collection Observable
940+
* @tparam R the type of items emitted by the resulting Observable
941+
* @param collectionSelector a function that returns an Observable for each item emitted by the source Observable
942+
* @param resultSelector a function that combines one item emitted by each of the source and collection Observables and
943+
* returns an item to be emitted by the resulting Observable
944+
* @return an Observable that emits the results of applying a function to a pair of values emitted by the
945+
* source Observable and the collection Observable
946+
*/
947+
def flatMap[U, R](collectionSelector: T => Observable[U], resultSelector: (T, U) => R): Observable[R] = {
948+
val jCollectionSelector = new Func1[T, rx.Observable[_ <: U]] {
949+
override def call(t: T): rx.Observable[_ <: U] = collectionSelector(t).asJavaObservable
950+
}
951+
toScalaObservable[R](asJavaObservable.mergeMap[U, R](jCollectionSelector, resultSelector))
952+
}
953+
954+
/**
955+
* Returns an Observable that merges each item emitted by the source Observable with the values in an
956+
* Iterable corresponding to that item that is generated by a selector.
957+
*
958+
* <img width="640" height="310" src="https://raw.github.com/wiki/Netflix/RxJava/images/rx-operators/mergeMapIterable.png">
959+
*
960+
* @tparam R the type of item emitted by the resulting Observable
961+
* @param collectionSelector a function that returns an Iterable sequence of values for when given an item emitted by the
962+
* source Observable
963+
* @return an Observable that emits the results of merging the items emitted by the source Observable with
964+
* the values in the Iterables corresponding to those items, as generated by `collectionSelector
965+
*/
966+
def flatMapIterable[R](collectionSelector: T => Iterable[R]): Observable[R] = {
967+
val jCollectionSelector = new Func1[T, java.lang.Iterable[_ <: R]] {
968+
override def call(t: T): java.lang.Iterable[_ <: R] = collectionSelector(t).asJava
969+
}
970+
toScalaObservable[R](asJavaObservable.mergeMapIterable[R](jCollectionSelector))
971+
}
972+
973+
/**
974+
* Returns an Observable that emits the results of applying a function to the pair of values from the source
975+
* Observable and an Iterable corresponding to that item that is generated by a selector.
976+
*
977+
* <img width="640" height="390" src="https://raw.github.com/wiki/Netflix/RxJava/images/rx-operators/mergeMapIterable.r.png">
978+
*
979+
* @tparam U the collection element type
980+
* @tparam R the type of item emited by the resulting Observable
981+
* @param collectionSelector a function that returns an Iterable sequence of values for each item emitted by the source
982+
* Observable
983+
* @param resultSelector a function that returns an item based on the item emitted by the source Observable and the
984+
* Iterable returned for that item by the `collectionSelector`
985+
* @return an Observable that emits the items returned by `resultSelector` for each item in the source Observable
986+
*/
987+
def flatMapIterable[U, R](collectionSelector: T => Iterable[U], resultSelector: (T, U) => R): Observable[R] = {
988+
val jCollectionSelector = new Func1[T, java.lang.Iterable[_ <: U]] {
989+
override def call(t: T): java.lang.Iterable[_ <: U] = collectionSelector(t).asJava
990+
}
991+
toScalaObservable[R](asJavaObservable.mergeMapIterable[U, R](jCollectionSelector, resultSelector))
992+
}
993+
886994
/**
887995
* Returns an Observable that applies the given function to each item emitted by an
888996
* Observable and emits the result.

language-adaptors/rxjava-scala/src/test/scala/rx/lang/scala/CompletenessTest.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ class CompletenessTest extends JUnitSuite {
9595
"lift(Operator[_ <: R, _ >: T])" -> "lift(Subscriber[R] => Subscriber[T])",
9696
"limit(Int)" -> "take(Int)",
9797
"mapWithIndex(Func2[_ >: T, Integer, _ <: R])" -> "[combine `zipWithIndex` with `map` or with a for comprehension]",
98+
"mergeMap(Func1[_ >: T, _ <: Observable[_ <: R]])" -> "flatMap(T => Observable[R])",
99+
"mergeMap(Func1[_ >: T, _ <: Observable[_ <: R]], Func1[_ >: Throwable, _ <: Observable[_ <: R]], Func0[_ <: Observable[_ <: R]])" -> "flatMap(T => Observable[R], Throwable => Observable[R], () => Observable[R])",
100+
"mergeMap(Func1[_ >: T, _ <: Observable[_ <: U]], Func2[_ >: T, _ >: U, _ <: R])" -> "flatMap(T => Observable[U], (T, U) => R)",
101+
"mergeMapIterable(Func1[_ >: T, _ <: Iterable[_ <: R]])" -> "flatMapIterable(T => Iterable[R])",
102+
"mergeMapIterable(Func1[_ >: T, _ <: Iterable[_ <: U]], Func2[_ >: T, _ >: U, _ <: R])" -> "flatMapIterable(T => Iterable[U], (T, U) => R)",
98103
"multicast(Subject[_ >: T, _ <: R])" -> "multicast(Subject[R])",
99104
"multicast(Func0[_ <: Subject[_ >: T, _ <: TIntermediate]], Func1[_ >: Observable[TIntermediate], _ <: Observable[TResult]])" -> "multicast(() => Subject[R], Observable[R] => Observable[U])",
100105
"onErrorResumeNext(Func1[Throwable, _ <: Observable[_ <: T]])" -> "onErrorResumeNext(Throwable => Observable[U])",

0 commit comments

Comments
 (0)