Skip to content

Commit ee22026

Browse files
He-Pinmdedetrich
authored andcommitted
chore: Use array list for better performance in BroadcastHub (#2262)
* chore: Use array list for better performance in BroadcastHub * chore: add benchmark (cherry picked from commit 3cfe37f)
1 parent a4f982e commit ee22026

File tree

2 files changed

+109
-14
lines changed

2 files changed

+109
-14
lines changed
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.pekko.stream
19+
20+
import com.typesafe.config.ConfigFactory
21+
import org.apache.pekko.NotUsed
22+
import org.apache.pekko.actor.ActorSystem
23+
import org.apache.pekko.remote.artery.{ BenchTestSource, LatchSink }
24+
import org.apache.pekko.stream.scaladsl._
25+
import org.apache.pekko.stream.testkit.scaladsl.StreamTestKit
26+
import org.openjdk.jmh.annotations._
27+
28+
import java.util.concurrent.{ CountDownLatch, TimeUnit }
29+
import scala.concurrent.Await
30+
import scala.concurrent.duration._
31+
32+
object BroadcastHubBenchmark {
33+
final val OperationsPerInvocation = 100000
34+
}
35+
36+
@State(Scope.Benchmark)
37+
@OutputTimeUnit(TimeUnit.SECONDS)
38+
@BenchmarkMode(Array(Mode.Throughput))
39+
class BroadcastHubBenchmark {
40+
import BroadcastHubBenchmark._
41+
42+
val config = ConfigFactory.parseString("""
43+
pekko.actor.default-dispatcher {
44+
executor = "fork-join-executor"
45+
fork-join-executor {
46+
parallelism-factor = 1
47+
}
48+
}
49+
""")
50+
51+
implicit val system: ActorSystem = ActorSystem("BroadcastHubBenchmark", config)
52+
import system.dispatcher
53+
54+
var testSource: Source[java.lang.Integer, NotUsed] = _
55+
56+
@Param(Array("64", "256"))
57+
var parallelism = 0
58+
59+
@Setup
60+
def setup(): Unit = {
61+
// eager init of materializer
62+
SystemMaterializer(system).materializer
63+
testSource = Source.fromGraph(new BenchTestSource(OperationsPerInvocation))
64+
}
65+
66+
@TearDown
67+
def shutdown(): Unit = {
68+
Await.result(system.terminate(), 5.seconds)
69+
}
70+
71+
@Benchmark
72+
@OperationsPerInvocation(OperationsPerInvocation)
73+
def broadcast(): Unit = {
74+
val latch = new CountDownLatch(parallelism)
75+
val broadcastSink =
76+
BroadcastHub.sink[java.lang.Integer](bufferSize = parallelism, startAfterNrOfConsumers = parallelism)
77+
val sink = new LatchSink(OperationsPerInvocation, latch)
78+
val source = testSource.runWith(broadcastSink)
79+
var idx = 0
80+
while (idx < parallelism) {
81+
source.runWith(sink)
82+
idx += 1
83+
}
84+
awaitLatch(latch)
85+
}
86+
87+
private def awaitLatch(latch: CountDownLatch): Unit = {
88+
if (!latch.await(30, TimeUnit.SECONDS)) {
89+
StreamTestKit.printDebugDump(SystemMaterializer(system).materializer.supervisor)
90+
throw new RuntimeException("Latch didn't complete in time")
91+
}
92+
}
93+
94+
}

stream/src/main/scala/org/apache/pekko/stream/scaladsl/Hub.scala

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,8 @@ private[pekko] class BroadcastHub[T](startAfterNrOfConsumers: Int, bufferSize: I
543543
* a wakeup and update their position at the same time.
544544
*
545545
*/
546-
private[this] val consumerWheel = Array.fill[List[Consumer]](bufferSize * 2)(Nil)
546+
private[this] val consumerWheel =
547+
Array.fill[java.util.ArrayList[Consumer]](bufferSize * 2)(new util.ArrayList[Consumer]())
547548
private[this] var activeConsumers = 0
548549

549550
override def preStart(): Unit = {
@@ -651,8 +652,10 @@ private[pekko] class BroadcastHub[T](startAfterNrOfConsumers: Int, bufferSize: I
651652
}
652653

653654
// Notify registered consumers
654-
consumerWheel.iterator.flatMap(_.iterator).foreach { consumer =>
655-
consumer.callback.invoke(failMessage)
655+
var idx = 0
656+
while (idx < consumerWheel.length) {
657+
consumerWheel(idx).forEach(_.callback.invoke(failMessage))
658+
idx += 1
656659
}
657660
failStage(ex)
658661
}
@@ -666,18 +669,16 @@ private[pekko] class BroadcastHub[T](startAfterNrOfConsumers: Int, bufferSize: I
666669
private def findAndRemoveConsumer(id: Long, offset: Int): Consumer = {
667670
// TODO: Try to eliminate modulo division somehow...
668671
val wheelSlot = offset & WheelMask
669-
var consumersInSlot = consumerWheel(wheelSlot)
670-
// debug(s"consumers before removal $consumersInSlot")
671-
var remainingConsumersInSlot: List[Consumer] = Nil
672+
val consumersInSlot = consumerWheel(wheelSlot)
672673
var removedConsumer: Consumer = null
673-
674-
while (consumersInSlot.nonEmpty) {
675-
val consumer = consumersInSlot.head
676-
if (consumer.id != id) remainingConsumersInSlot = consumer :: remainingConsumersInSlot
677-
else removedConsumer = consumer
678-
consumersInSlot = consumersInSlot.tail
674+
if (consumersInSlot.size() > 0) {
675+
consumersInSlot.removeIf(consumer => {
676+
if (consumer.id == id) {
677+
removedConsumer = consumer
678+
true
679+
} else false
680+
})
679681
}
680-
consumerWheel(wheelSlot) = remainingConsumersInSlot
681682
removedConsumer
682683
}
683684

@@ -708,7 +709,7 @@ private[pekko] class BroadcastHub[T](startAfterNrOfConsumers: Int, bufferSize: I
708709

709710
private def addConsumer(consumer: Consumer, offset: Int): Unit = {
710711
val slot = offset & WheelMask
711-
consumerWheel(slot) = consumer :: consumerWheel(slot)
712+
consumerWheel(slot).add(consumer)
712713
}
713714

714715
/*

0 commit comments

Comments
 (0)