Skip to content

Commit fe3cc75

Browse files
committed
Merge pull request #2552 from akarnokd/OperatorPublishRequestFix
Publish: fixed incorrect subscriber requested accounting
2 parents 19c96cd + aeb879a commit fe3cc75

File tree

2 files changed

+109
-49
lines changed

2 files changed

+109
-49
lines changed

src/main/java/rx/internal/operators/OperatorPublish.java

Lines changed: 43 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,6 @@ public void onNext(T t) {
206206
*/
207207
private static class State<T> {
208208
private long outstandingRequests = -1;
209-
private long emittedSinceRequest = 0;
210209
private OriginSubscriber<T> origin;
211210
// using AtomicLong to simplify mutating it, not for thread-safety since we're synchronizing access to this class
212211
// using LinkedHashMap so the order of Subscribers having onNext invoked is deterministic (same each time the code is run)
@@ -225,15 +224,13 @@ public synchronized void setOrigin(OriginSubscriber<T> o) {
225224
public synchronized boolean canEmitWithDecrement() {
226225
if (outstandingRequests > 0) {
227226
outstandingRequests--;
228-
emittedSinceRequest++;
229227
return true;
230228
}
231229
return false;
232230
}
233231

234232
public synchronized void incrementOutstandingAfterFailedEmit() {
235233
outstandingRequests++;
236-
emittedSinceRequest--;
237234
}
238235

239236
public synchronized Subscriber<? super T>[] getSubscribers() {
@@ -243,50 +240,55 @@ public synchronized Subscriber<? super T>[] getSubscribers() {
243240
/**
244241
* @return long outstandingRequests
245242
*/
246-
public synchronized long requestFromSubscriber(Subscriber<? super T> subscriber, Long request) {
247-
AtomicLong r = ss.get(subscriber);
243+
public synchronized long requestFromSubscriber(Subscriber<? super T> subscriber, long request) {
244+
Map<Subscriber<? super T>, AtomicLong> subs = ss;
245+
AtomicLong r = subs.get(subscriber);
248246
if (r == null) {
249-
ss.put(subscriber, new AtomicLong(request));
247+
subs.put(subscriber, new AtomicLong(request));
250248
} else {
251-
if (r.get() != Long.MAX_VALUE) {
252-
if (request == Long.MAX_VALUE) {
253-
r.set(Long.MAX_VALUE);
254-
} else {
255-
r.addAndGet(request.longValue());
249+
do {
250+
long current = r.get();
251+
if (current == Long.MAX_VALUE) {
252+
break;
256253
}
257-
}
254+
long u = current + request;
255+
if (u < 0) {
256+
u = Long.MAX_VALUE;
257+
}
258+
if (r.compareAndSet(current, u)) {
259+
break;
260+
}
261+
} while (true);
258262
}
259263

260-
return resetAfterSubscriberUpdate();
264+
return resetAfterSubscriberUpdate(subs);
261265
}
262266

263267
public synchronized void removeSubscriber(Subscriber<? super T> subscriber) {
264-
ss.remove(subscriber);
265-
resetAfterSubscriberUpdate();
268+
Map<Subscriber<? super T>, AtomicLong> subs = ss;
269+
subs.remove(subscriber);
270+
resetAfterSubscriberUpdate(subs);
266271
}
267272

268273
@SuppressWarnings("unchecked")
269-
private long resetAfterSubscriberUpdate() {
270-
subscribers = new Subscriber[ss.size()];
274+
private long resetAfterSubscriberUpdate(Map<Subscriber<? super T>, AtomicLong> subs) {
275+
Subscriber<? super T>[] subscriberArray = new Subscriber[subs.size()];
271276
int i = 0;
272-
for (Subscriber<? super T> s : ss.keySet()) {
273-
subscribers[i++] = s;
274-
}
275-
276277
long lowest = -1;
277-
for (AtomicLong l : ss.values()) {
278-
// decrement all we have emitted since last request
279-
long c = l.addAndGet(-emittedSinceRequest);
278+
for (Map.Entry<Subscriber<? super T>, AtomicLong> e : subs.entrySet()) {
279+
subscriberArray[i++] = e.getKey();
280+
AtomicLong l = e.getValue();
281+
long c = l.get();
280282
if (lowest == -1 || c < lowest) {
281283
lowest = c;
282284
}
283285
}
286+
this.subscribers = subscriberArray;
284287
/*
285288
* when receiving a request from a subscriber we reset 'outstanding' to the lowest of all subscribers
286289
*/
287290
outstandingRequests = lowest;
288-
emittedSinceRequest = 0;
289-
return outstandingRequests;
291+
return lowest;
290292
}
291293
}
292294

@@ -299,7 +301,7 @@ private static class RequestHandler<T> {
299301
@SuppressWarnings("rawtypes")
300302
static final AtomicLongFieldUpdater<RequestHandler> WIP = AtomicLongFieldUpdater.newUpdater(RequestHandler.class, "wip");
301303

302-
public void requestFromChildSubscriber(Subscriber<? super T> subscriber, Long request) {
304+
public void requestFromChildSubscriber(Subscriber<? super T> subscriber, long request) {
303305
state.requestFromSubscriber(subscriber, request);
304306
OriginSubscriber<T> originSubscriber = state.getOrigin();
305307
if(originSubscriber != null) {
@@ -333,6 +335,11 @@ private void requestMoreAfterEmission(int emitted) {
333335

334336
public void drainQueue(OriginSubscriber<T> originSubscriber) {
335337
if (WIP.getAndIncrement(this) == 0) {
338+
State<T> localState = state;
339+
Map<Subscriber<? super T>, AtomicLong> localMap = localState.ss;
340+
RxRingBuffer localBuffer = originSubscriber.buffer;
341+
NotificationLite<T> nl = notifier;
342+
336343
int emitted = 0;
337344
do {
338345
/*
@@ -345,26 +352,24 @@ public void drainQueue(OriginSubscriber<T> originSubscriber) {
345352
* If we want to batch this then we need to account for new subscribers arriving with a lower request count
346353
* concurrently while iterating the batch ... or accept that they won't
347354
*/
355+
348356
while (true) {
349-
boolean shouldEmit = state.canEmitWithDecrement();
357+
boolean shouldEmit = localState.canEmitWithDecrement();
350358
if (!shouldEmit) {
351359
break;
352360
}
353-
Object o = originSubscriber.buffer.poll();
361+
Object o = localBuffer.poll();
354362
if (o == null) {
355363
// nothing in buffer so increment outstanding back again
356-
state.incrementOutstandingAfterFailedEmit();
364+
localState.incrementOutstandingAfterFailedEmit();
357365
break;
358366
}
359367

360-
if (notifier.isCompleted(o)) {
361-
for (Subscriber<? super T> s : state.getSubscribers()) {
362-
notifier.accept(s, o);
363-
}
364-
365-
} else {
366-
for (Subscriber<? super T> s : state.getSubscribers()) {
367-
notifier.accept(s, o);
368+
for (Subscriber<? super T> s : localState.getSubscribers()) {
369+
AtomicLong req = localMap.get(s);
370+
if (req != null) { // null req indicates a concurrent unsubscription happened
371+
nl.accept(s, o);
372+
req.decrementAndGet();
368373
}
369374
}
370375
emitted++;

src/test/java/rx/internal/operators/OperatorPublishTest.java

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,17 @@
1515
*/
1616
package rx.internal.operators;
1717

18-
import static org.junit.Assert.assertEquals;
19-
import static org.junit.Assert.fail;
18+
import static org.junit.Assert.*;
2019

2120
import java.util.Arrays;
22-
import java.util.concurrent.CountDownLatch;
23-
import java.util.concurrent.TimeUnit;
24-
import java.util.concurrent.atomic.AtomicInteger;
21+
import java.util.concurrent.*;
22+
import java.util.concurrent.atomic.*;
2523

2624
import org.junit.Test;
2725

28-
import rx.Observable;
26+
import rx.*;
2927
import rx.Observable.OnSubscribe;
30-
import rx.Subscriber;
31-
import rx.Subscription;
32-
import rx.functions.Action0;
33-
import rx.functions.Action1;
34-
import rx.functions.Func1;
28+
import rx.functions.*;
3529
import rx.internal.util.RxRingBuffer;
3630
import rx.observables.ConnectableObservable;
3731
import rx.observers.TestSubscriber;
@@ -187,4 +181,65 @@ public Boolean call(Integer i) {
187181
System.out.println(ts.getOnNextEvents());
188182
}
189183

184+
@Test(timeout = 10000)
185+
public void testBackpressureTwoConsumers() {
186+
final AtomicInteger sourceEmission = new AtomicInteger();
187+
final AtomicBoolean sourceUnsubscribed = new AtomicBoolean();
188+
final Observable<Integer> source = Observable.range(1, 100)
189+
.doOnNext(new Action1<Integer>() {
190+
@Override
191+
public void call(Integer t1) {
192+
sourceEmission.incrementAndGet();
193+
}
194+
})
195+
.doOnUnsubscribe(new Action0() {
196+
@Override
197+
public void call() {
198+
sourceUnsubscribed.set(true);
199+
}
200+
}).share();
201+
;
202+
203+
final AtomicBoolean child1Unsubscribed = new AtomicBoolean();
204+
final AtomicBoolean child2Unsubscribed = new AtomicBoolean();
205+
206+
final TestSubscriber<Integer> ts2 = new TestSubscriber<Integer>();
207+
208+
final TestSubscriber<Integer> ts1 = new TestSubscriber<Integer>() {
209+
@Override
210+
public void onNext(Integer t) {
211+
if (getOnNextEvents().size() == 2) {
212+
source.doOnUnsubscribe(new Action0() {
213+
@Override
214+
public void call() {
215+
child2Unsubscribed.set(true);
216+
}
217+
}).take(5).subscribe(ts2);
218+
}
219+
super.onNext(t);
220+
}
221+
};
222+
223+
source.doOnUnsubscribe(new Action0() {
224+
@Override
225+
public void call() {
226+
child1Unsubscribed.set(true);
227+
}
228+
}).take(5).subscribe(ts1);
229+
230+
ts1.awaitTerminalEvent();
231+
ts2.awaitTerminalEvent();
232+
233+
ts1.assertNoErrors();
234+
ts2.assertNoErrors();
235+
236+
assertTrue(sourceUnsubscribed.get());
237+
assertTrue(child1Unsubscribed.get());
238+
assertTrue(child2Unsubscribed.get());
239+
240+
ts1.assertReceivedOnNext(Arrays.asList(1, 2, 3, 4, 5));
241+
ts2.assertReceivedOnNext(Arrays.asList(4, 5, 6, 7, 8));
242+
243+
assertEquals(8, sourceEmission.get());
244+
}
190245
}

0 commit comments

Comments
 (0)