Skip to content

Commit f5c00d7

Browse files
tkromanrozza
authored andcommitted
Fix race condition in FlatMapObservable completion handler
JAVA-4241
1 parent 05fdd7e commit f5c00d7

File tree

2 files changed

+93
-28
lines changed

2 files changed

+93
-28
lines changed

driver-scala/src/main/scala/org/mongodb/scala/internal/FlatMapObservable.scala

Lines changed: 55 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -18,72 +18,99 @@ package org.mongodb.scala.internal
1818

1919
import org.mongodb.scala._
2020

21+
import java.util.concurrent.atomic.AtomicReference
22+
23+
sealed trait State
24+
case object Init extends State
25+
case class WaitingOnChild(s: Subscription) extends State
26+
case object LastChildNotified extends State
27+
case object LastChildResponded extends State
28+
case object Done extends State
29+
case object Error extends State
30+
2131
private[scala] case class FlatMapObservable[T, S](observable: Observable[T], f: T => Observable[S])
2232
extends Observable[S] {
23-
2433
// scalastyle:off cyclomatic.complexity method.length
2534
override def subscribe(observer: Observer[_ >: S]): Unit = {
2635
observable.subscribe(
2736
SubscriptionCheckingObserver(
2837
new Observer[T] {
29-
30-
@volatile
31-
private var outerSubscription: Option[Subscription] = None
32-
@volatile
33-
private var nestedSubscription: Option[Subscription] = None
34-
@volatile
35-
private var demand: Long = 0
36-
@volatile
37-
private var onCompleteCalled: Boolean = false
38+
@volatile private var outerSubscription: Option[Subscription] = None
39+
@volatile private var demand: Long = 0
40+
private val state = new AtomicReference[State](Init)
3841

3942
override def onSubscribe(subscription: Subscription): Unit = {
4043
val masterSub = new Subscription() {
4144
override def isUnsubscribed: Boolean = subscription.isUnsubscribed
42-
43-
def request(n: Long): Unit = {
45+
override def unsubscribe(): Unit = subscription.unsubscribe()
46+
override def request(n: Long): Unit = {
4447
require(n > 0L, s"Number requested must be greater than zero: $n")
4548
val localDemand = addDemand(n)
46-
val (sub, num) = nestedSubscription.map((_, localDemand)).getOrElse((subscription, 1L))
47-
sub.request(num)
49+
state.get() match {
50+
case Init => subscription.request(1L)
51+
case WaitingOnChild(s) => s.request(localDemand)
52+
case _ => // noop
53+
}
4854
}
49-
50-
override def unsubscribe(): Unit = subscription.unsubscribe()
5155
}
52-
5356
outerSubscription = Some(masterSub)
57+
state.set(Init)
5458
observer.onSubscribe(masterSub)
5559
}
5660

5761
override def onComplete(): Unit = {
58-
if (!onCompleteCalled) {
59-
onCompleteCalled = true
60-
if (nestedSubscription.isEmpty) observer.onComplete()
62+
state.get() match {
63+
case Done => // ok
64+
case Error => // ok
65+
case Init if state.compareAndSet(Init, Done) =>
66+
observer.onComplete()
67+
case w @ WaitingOnChild(_) if state.compareAndSet(w, LastChildNotified) =>
68+
// letting the child know that we delegate onComplete call to it
69+
case LastChildNotified =>
70+
// wait for the child to do the delegated onCompleteCall
71+
case LastChildResponded if state.compareAndSet(LastChildResponded, Done) =>
72+
observer.onComplete()
73+
case other =>
74+
// state machine is broken, let's fail
75+
// normally this won't happen
76+
throw new IllegalStateException(s"Unexpected state in FlatMapObservable `onComplete` handler: ${other}")
6177
}
6278
}
6379

64-
override def onError(throwable: Throwable): Unit = observer.onError(throwable)
80+
override def onError(throwable: Throwable): Unit = {
81+
observer.onError(throwable)
82+
}
6583

6684
override def onNext(tResult: T): Unit = {
6785
f(tResult).subscribe(
6886
new Observer[S]() {
6987
override def onError(throwable: Throwable): Unit = {
70-
nestedSubscription = None
88+
state.set(Error)
7189
observer.onError(throwable)
7290
}
7391

7492
override def onSubscribe(subscription: Subscription): Unit = {
75-
nestedSubscription = Some(subscription)
93+
state.set(WaitingOnChild(subscription))
7694
if (demand > 0) subscription.request(demand)
7795
}
7896

7997
override def onComplete(): Unit = {
80-
nestedSubscription = None
81-
onCompleteCalled match {
82-
case true => observer.onComplete()
83-
case false if demand > 0 =>
98+
state.get() match {
99+
case Done => // no need to call parent's onComplete
100+
case Error => // no need to call parent's onComplete
101+
case LastChildNotified if state.compareAndSet(LastChildNotified, LastChildResponded) =>
102+
// parent told us to call onComplete
103+
observer.onComplete()
104+
case _ if demand > 0 =>
105+
// otherwise we are not the last child, let's tell the parent
106+
// it's not dealing with us anymore.
107+
// Init -> * will be handled by possible later items in the stream
108+
state.set(Init)
84109
addDemand(-1) // reduce demand by 1 as it will be incremented by the outerSubscription
85110
outerSubscription.foreach(_.request(1))
86-
case false => // No more demand
111+
case _ =>
112+
// no demand
113+
state.set(Init)
87114
}
88115
}
89116

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package org.mongodb.scala.internal
2+
3+
import org.mongodb.scala.{ BaseSpec, Observable, Observer }
4+
import org.scalatest.concurrent.{ Eventually, Futures }
5+
6+
import java.util.concurrent.atomic.AtomicInteger
7+
import scala.concurrent.ExecutionContext.Implicits.global
8+
import scala.concurrent.{ Future, Promise }
9+
10+
class FlatMapObservableTest extends BaseSpec with Futures with Eventually {
11+
"FlatMapObservable" should "only complete once" in {
12+
val p = Promise[Unit]()
13+
val completedCounter = new AtomicInteger(0)
14+
Observable(1 to 100)
15+
.flatMap(
16+
x =>
17+
(observer: Observer[_ >: Int]) => {
18+
Future(()).onComplete(_ => {
19+
observer.onNext(x)
20+
observer.onComplete()
21+
})
22+
}
23+
)
24+
.subscribe(
25+
_ => (),
26+
p.failure,
27+
() => {
28+
completedCounter.incrementAndGet()
29+
Thread.sleep(100)
30+
p.trySuccess(())
31+
}
32+
)
33+
eventually(assert(completedCounter.get() == 1, s"${completedCounter.get()}"))
34+
Thread.sleep(200)
35+
assert(completedCounter.get() == 1, s"${completedCounter.get()}")
36+
Thread.sleep(1000)
37+
}
38+
}

0 commit comments

Comments
 (0)