Skip to content

Commit c5691fb

Browse files
committed
Add flatMap variants to RxScala
1 parent 400f69b commit c5691fb

File tree

3 files changed

+155
-0
lines changed

3 files changed

+155
-0
lines changed

language-adaptors/rxjava-scala/src/examples/scala/rx/lang/scala/examples/RxScalaDemo.scala

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,4 +1108,65 @@ class RxScalaDemo extends JUnitSuite {
11081108
}
11091109
o.toBlockingObservable.foreach(println(_))
11101110
}
1111+
1112+
@Test def flatMapExample() {
1113+
val o = Observable.items(10, 100)
1114+
o.flatMap(n => Observable.interval(200 millis).map(_ * n))
1115+
.take(20)
1116+
.toBlocking.foreach(println)
1117+
}
1118+
1119+
@Test def flatMapExample2() {
1120+
val o = Observable.items(10, 100)
1121+
val o1 = for (n <- o;
1122+
i <- Observable.interval(200 millis)) yield i * n
1123+
o1.take(20).toBlocking.foreach(println)
1124+
}
1125+
1126+
@Test def flatMapExample3() {
1127+
val o = Observable[Int] {
1128+
subscriber =>
1129+
subscriber.onNext(10)
1130+
subscriber.onNext(100)
1131+
subscriber.onError(new IOException("Oops"))
1132+
}
1133+
o.flatMap(
1134+
(n: Int) => Observable.interval(200 millis).map(_ * n),
1135+
e => Observable.interval(200 millis).map(_ * -1),
1136+
() => Observable.interval(200 millis).map(_ * 1000)
1137+
).take(20)
1138+
.toBlocking.foreach(println)
1139+
}
1140+
1141+
@Test def flatMapExample4() {
1142+
val o = Observable.items(10, 100)
1143+
o.flatMap(
1144+
(n: Int) => Observable.interval(200 millis).map(_ * n),
1145+
e => Observable.interval(200 millis).map(_ * -1),
1146+
() => Observable.interval(200 millis).map(_ * 1000)
1147+
).take(20)
1148+
.toBlocking.foreach(println)
1149+
}
1150+
1151+
@Test def flatMapExample5() {
1152+
val o = Observable.items(1, 10, 100, 1000)
1153+
o.flatMap(
1154+
(n: Int) => Observable.interval(200 millis).take(5),
1155+
(n: Int, m: Long) => n * m
1156+
).toBlocking.foreach(println)
1157+
}
1158+
1159+
@Test def flatMapIterableExample() {
1160+
val o = Observable.items(10, 100)
1161+
o.flatMapIterable(n => (1 to 20).map(_ * n))
1162+
.toBlocking.foreach(println)
1163+
}
1164+
1165+
@Test def flatMapIterableExample2() {
1166+
val o = Observable.items(1, 10, 100, 1000)
1167+
o.flatMapIterable(
1168+
(n: Int) => (1 to 5),
1169+
(n: Int, m: Int) => n * m
1170+
).toBlocking.foreach(println)
1171+
}
11111172
}

language-adaptors/rxjava-scala/src/main/scala/rx/lang/scala/Observable.scala

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,95 @@ trait Observable[+T]
883883
}))
884884
}
885885

886+
/**
887+
* Returns an Observable that applies a function to each item emitted or notification raised by the source
888+
* Observable and then flattens the Observables returned from these functions and emits the resulting items.
889+
*
890+
* <img width="640" height="410" src="https://raw.github.com/wiki/Netflix/RxJava/images/rx-operators/mergeMap.nce.png">
891+
*
892+
* @tparam R the result type
893+
* @param onNext a function that returns an Observable to merge for each item emitted by the source Observable
894+
* @param onError a function that returns an Observable to merge for an onError notification from the source
895+
* Observable
896+
* @param onCompleted a function that returns an Observable to merge for an onCompleted notification from the source
897+
* Observable
898+
* @return an Observable that emits the results of merging the Observables returned from applying the
899+
* specified functions to the emissions and notifications of the source Observable
900+
*/
901+
def flatMap[R](onNext: T => Observable[R], onError: Throwable => Observable[R], onCompleted: () => Observable[R]): Observable[R] = {
902+
val jOnNext = new Func1[T, rx.Observable[_ <: R]] {
903+
override def call(t: T): rx.Observable[_ <: R] = onNext(t).asJavaObservable
904+
}
905+
val jOnError = new Func1[Throwable, rx.Observable[_ <: R]] {
906+
override def call(e: Throwable): rx.Observable[_ <: R] = onError(e).asJavaObservable
907+
}
908+
val jOnCompleted = new Func0[rx.Observable[_ <: R]] {
909+
override def call(): rx.Observable[_ <: R] = onCompleted().asJavaObservable
910+
}
911+
toScalaObservable[R](asJavaObservable.mergeMap[R](jOnNext, jOnError, jOnCompleted))
912+
}
913+
914+
/**
915+
* Returns an Observable that emits the results of a specified function to the pair of values emitted by the
916+
* source Observable and a specified collection Observable.
917+
*
918+
* <img width="640" height="390" src="https://raw.github.com/wiki/Netflix/RxJava/images/rx-operators/mergeMap.r.png">
919+
*
920+
* @tparam U the type of items emitted by the collection Observable
921+
* @tparam R the type of items emitted by the resulting Observable
922+
* @param collectionSelector a function that returns an Observable for each item emitted by the source Observable
923+
* @param resultSelector a function that combines one item emitted by each of the source and collection Observables and
924+
* returns an item to be emitted by the resulting Observable
925+
* @return an Observable that emits the results of applying a function to a pair of values emitted by the
926+
* source Observable and the collection Observable
927+
*/
928+
def flatMap[U, R](collectionSelector: T => Observable[U], resultSelector: (T, U) => R): Observable[R] = {
929+
val jCollectionSelector = new Func1[T, rx.Observable[_ <: U]] {
930+
override def call(t: T): rx.Observable[_ <: U] = collectionSelector(t).asJavaObservable
931+
}
932+
toScalaObservable[R](asJavaObservable.mergeMap[U, R](jCollectionSelector, resultSelector))
933+
}
934+
935+
/**
936+
* Returns an Observable that merges each item emitted by the source Observable with the values in an
937+
* Iterable corresponding to that item that is generated by a selector.
938+
*
939+
* <img width="640" height="310" src="https://raw.github.com/wiki/Netflix/RxJava/images/rx-operators/mergeMapIterable.png">
940+
*
941+
* @tparam R the type of item emitted by the resulting Observable
942+
* @param collectionSelector a function that returns an Iterable sequence of values for when given an item emitted by the
943+
* source Observable
944+
* @return an Observable that emits the results of merging the items emitted by the source Observable with
945+
* the values in the Iterables corresponding to those items, as generated by `collectionSelector
946+
*/
947+
def flatMapIterable[R](collectionSelector: T => Iterable[R]): Observable[R] = {
948+
val jCollectionSelector = new Func1[T, java.lang.Iterable[_ <: R]] {
949+
override def call(t: T): java.lang.Iterable[_ <: R] = collectionSelector(t).asJava
950+
}
951+
toScalaObservable[R](asJavaObservable.mergeMapIterable[R](jCollectionSelector))
952+
}
953+
954+
/**
955+
* Returns an Observable that emits the results of applying a function to the pair of values from the source
956+
* Observable and an Iterable corresponding to that item that is generated by a selector.
957+
*
958+
* <img width="640" height="390" src="https://raw.github.com/wiki/Netflix/RxJava/images/rx-operators/mergeMapIterable.r.png">
959+
*
960+
* @tparam U the collection element type
961+
* @tparam R the type of item emited by the resulting Observable
962+
* @param collectionSelector a function that returns an Iterable sequence of values for each item emitted by the source
963+
* Observable
964+
* @param resultSelector a function that returns an item based on the item emitted by the source Observable and the
965+
* Iterable returned for that item by the `collectionSelector`
966+
* @return an Observable that emits the items returned by `resultSelector` for each item in the source Observable
967+
*/
968+
def flatMapIterable[U, R](collectionSelector: T => Iterable[U], resultSelector: (T, U) => R): Observable[R] = {
969+
val jCollectionSelector = new Func1[T, java.lang.Iterable[_ <: U]] {
970+
override def call(t: T): java.lang.Iterable[_ <: U] = collectionSelector(t).asJava
971+
}
972+
toScalaObservable[R](asJavaObservable.mergeMapIterable[U, R](jCollectionSelector, resultSelector))
973+
}
974+
886975
/**
887976
* Returns an Observable that applies the given function to each item emitted by an
888977
* Observable and emits the result.

language-adaptors/rxjava-scala/src/test/scala/rx/lang/scala/CompletenessTest.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ class CompletenessTest extends JUnitSuite {
9595
"lift(Operator[_ <: R, _ >: T])" -> "lift(Subscriber[R] => Subscriber[T])",
9696
"limit(Int)" -> "take(Int)",
9797
"mapWithIndex(Func2[_ >: T, Integer, _ <: R])" -> "[combine `zipWithIndex` with `map` or with a for comprehension]",
98+
"mergeMap(Func1[_ >: T, _ <: Observable[_ <: R]])" -> "flatMap(T => Observable[R])",
99+
"mergeMap(Func1[_ >: T, _ <: Observable[_ <: R]], Func1[_ >: Throwable, _ <: Observable[_ <: R]], Func0[_ <: Observable[_ <: R]])" -> "flatMap(T => Observable[R], Throwable => Observable[R], () => Observable[R])",
100+
"mergeMap(Func1[_ >: T, _ <: Observable[_ <: U]], Func2[_ >: T, _ >: U, _ <: R])" -> "flatMap(T => Observable[U], (T, U) => R)",
101+
"mergeMapIterable(Func1[_ >: T, _ <: Iterable[_ <: R]])" -> "flatMapIterable(T => Iterable[R])",
102+
"mergeMapIterable(Func1[_ >: T, _ <: Iterable[_ <: U]], Func2[_ >: T, _ >: U, _ <: R])" -> "flatMapIterable(T => Iterable[U], (T, U) => R)",
98103
"multicast(Subject[_ >: T, _ <: R])" -> "multicast(Subject[R])",
99104
"multicast(Func0[_ <: Subject[_ >: T, _ <: TIntermediate]], Func1[_ >: Observable[TIntermediate], _ <: Observable[TResult]])" -> "multicast(() => Subject[R], Observable[R] => Observable[U])",
100105
"onErrorResumeNext(Func1[Throwable, _ <: Observable[_ <: T]])" -> "onErrorResumeNext(Throwable => Observable[U])",

0 commit comments

Comments
 (0)