@@ -18,72 +18,99 @@ package org.mongodb.scala.internal
18
18
19
19
import org .mongodb .scala ._
20
20
21
+ import java .util .concurrent .atomic .AtomicReference
22
+
23
+ sealed trait State
24
+ case object Init extends State
25
+ case class WaitingOnChild (s : Subscription ) extends State
26
+ case object LastChildNotified extends State
27
+ case object LastChildResponded extends State
28
+ case object Done extends State
29
+ case object Error extends State
30
+
21
31
private [scala] case class FlatMapObservable [T , S ](observable : Observable [T ], f : T => Observable [S ])
22
32
extends Observable [S ] {
23
-
24
33
// scalastyle:off cyclomatic.complexity method.length
25
34
override def subscribe (observer : Observer [_ >: S ]): Unit = {
26
35
observable.subscribe(
27
36
SubscriptionCheckingObserver (
28
37
new Observer [T ] {
29
-
30
- @ volatile
31
- private var outerSubscription : Option [Subscription ] = None
32
- @ volatile
33
- private var nestedSubscription : Option [Subscription ] = None
34
- @ volatile
35
- private var demand : Long = 0
36
- @ volatile
37
- private var onCompleteCalled : Boolean = false
38
+ @ volatile private var outerSubscription : Option [Subscription ] = None
39
+ @ volatile private var demand : Long = 0
40
+ private val state = new AtomicReference [State ](Init )
38
41
39
42
override def onSubscribe (subscription : Subscription ): Unit = {
40
43
val masterSub = new Subscription () {
41
44
override def isUnsubscribed : Boolean = subscription.isUnsubscribed
42
-
43
- def request (n : Long ): Unit = {
45
+ override def unsubscribe () : Unit = subscription.unsubscribe()
46
+ override def request (n : Long ): Unit = {
44
47
require(n > 0L , s " Number requested must be greater than zero: $n" )
45
48
val localDemand = addDemand(n)
46
- val (sub, num) = nestedSubscription.map((_, localDemand)).getOrElse((subscription, 1L ))
47
- sub.request(num)
49
+ state.get() match {
50
+ case Init => subscription.request(1L )
51
+ case WaitingOnChild (s) => s.request(localDemand)
52
+ case _ => // noop
53
+ }
48
54
}
49
-
50
- override def unsubscribe (): Unit = subscription.unsubscribe()
51
55
}
52
-
53
56
outerSubscription = Some (masterSub)
57
+ state.set(Init )
54
58
observer.onSubscribe(masterSub)
55
59
}
56
60
57
61
override def onComplete (): Unit = {
58
- if (! onCompleteCalled) {
59
- onCompleteCalled = true
60
- if (nestedSubscription.isEmpty) observer.onComplete()
62
+ state.get() match {
63
+ case Done => // ok
64
+ case Error => // ok
65
+ case Init if state.compareAndSet(Init , Done ) =>
66
+ observer.onComplete()
67
+ case w @ WaitingOnChild (_) if state.compareAndSet(w, LastChildNotified ) =>
68
+ // letting the child know that we delegate onComplete call to it
69
+ case LastChildNotified =>
70
+ // wait for the child to do the delegated onCompleteCall
71
+ case LastChildResponded if state.compareAndSet(LastChildResponded , Done ) =>
72
+ observer.onComplete()
73
+ case other =>
74
+ // state machine is broken, let's fail
75
+ // normally this won't happen
76
+ throw new IllegalStateException (s " Unexpected state in FlatMapObservable `onComplete` handler: ${other}" )
61
77
}
62
78
}
63
79
64
- override def onError (throwable : Throwable ): Unit = observer.onError(throwable)
80
+ override def onError (throwable : Throwable ): Unit = {
81
+ observer.onError(throwable)
82
+ }
65
83
66
84
override def onNext (tResult : T ): Unit = {
67
85
f(tResult).subscribe(
68
86
new Observer [S ]() {
69
87
override def onError (throwable : Throwable ): Unit = {
70
- nestedSubscription = None
88
+ state.set( Error )
71
89
observer.onError(throwable)
72
90
}
73
91
74
92
override def onSubscribe (subscription : Subscription ): Unit = {
75
- nestedSubscription = Some ( subscription)
93
+ state.set( WaitingOnChild ( subscription) )
76
94
if (demand > 0 ) subscription.request(demand)
77
95
}
78
96
79
97
override def onComplete (): Unit = {
80
- nestedSubscription = None
81
- onCompleteCalled match {
82
- case true => observer.onComplete()
83
- case false if demand > 0 =>
98
+ state.get() match {
99
+ case Done => // no need to call parent's onComplete
100
+ case Error => // no need to call parent's onComplete
101
+ case LastChildNotified if state.compareAndSet(LastChildNotified , LastChildResponded ) =>
102
+ // parent told us to call onComplete
103
+ observer.onComplete()
104
+ case _ if demand > 0 =>
105
+ // otherwise we are not the last child, let's tell the parent
106
+ // it's not dealing with us anymore.
107
+ // Init -> * will be handled by possible later items in the stream
108
+ state.set(Init )
84
109
addDemand(- 1 ) // reduce demand by 1 as it will be incremented by the outerSubscription
85
110
outerSubscription.foreach(_.request(1 ))
86
- case false => // No more demand
111
+ case _ =>
112
+ // no demand
113
+ state.set(Init )
87
114
}
88
115
}
89
116
0 commit comments