diff --git a/examples/streaming/streamingkmeans/README.md b/examples/streaming/streamingkmeans/README.md
new file mode 100644
index 000000000..18f22ece2
--- /dev/null
+++ b/examples/streaming/streamingkmeans/README.md
@@ -0,0 +1,41 @@
+Streaming k-means clustering
+==============================
+## Introduction
+This application follows streaming k-means clustering on Spark; see
+https://databricks.com/blog/2015/01/28/introducing-streaming-k-means-in-spark-1-2.html for details.
+
+The `DataSource` used is `RandomRBFSource`, a random RBF generator adapted from Huawei's `StreamDM`.
+
+## Gearpump topology
+The Gearpump topology is as follows:
+
+![kmeans](https://cloud.githubusercontent.com/assets/5796671/14097520/93a2b498-f5a4-11e5-8df8-ef2b62c3b5ff.PNG)
+
+The `Source Processor` produces points over time and broadcasts each point to the `Distribution Processor`.
+The `Distribution Processor` has k tasks, and each task holds one cluster center together with the points assigned to it.
+When a `Distribution Processor` task receives a point from the `Source Processor`, it calculates the distance from the point to its center and sends the distance, the point and its `taskId` to the `Collection Processor`.
+When the `Collection Processor` has received the distances from all `Distribution Processor` tasks, it counts the points seen so far, determines whether it is time to update the centers, picks the smallest distance and broadcasts the point together with the `taskId` of the closest `Distribution Processor` task.
+When the `Distribution Processor` receives this result message, the task with the matching `taskId` accumulates the point. If the message also signals that it is time to update the centers, every task updates its own center.
+
+This procedure is streaming, so the cluster centers change over time.
+
+## How to use it
+You can run this application with the following command:
+
+```
+bin/gear app -jar examples/streamingkmeans-2.11-0.7.7-SNAPSHOT-assembly.jar io.gearpump.streaming.examples.streamingkmeans.StreamingKmeansExample
+```
+
+Optionally, you can configure the clustering with the following options:
+
+```
+-k            <number of clusters, default 2>
+-dimension    <dimension of the data points, default 2>
+-maxBatch     <number of points the source emits per batch, default 1000>
+-maxNumber    <number of points processed between center updates, default 100>
+-decayFactor  <decay factor of the center update rule, default 1.0>
+```
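+
+For example, a full invocation that sets every option explicitly looks roughly like the
+following (the jar name and the option values are only illustrative; use the ones that match
+your build and data):
+
+```
+bin/gear app -jar examples/streamingkmeans-2.11-0.7.7-SNAPSHOT-assembly.jar io.gearpump.streaming.examples.streamingkmeans.StreamingKmeansExample -k 2 -dimension 2 -maxBatch 1000 -maxNumber 100 -decayFactor 1.0
+```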
+
+## Evaluation
+The `Distribution Processor` has k tasks, and each task saves one cluster center.
+Each task logs its cluster center whenever the center is updated.
diff --git a/examples/streaming/streamingkmeans/src/main/resources/geardefault.conf b/examples/streaming/streamingkmeans/src/main/resources/geardefault.conf
new file mode 100644
index 000000000..db7e44955
--- /dev/null
+++ b/examples/streaming/streamingkmeans/src/main/resources/geardefault.conf
@@ -0,0 +1,6 @@
+gearpump {
+  serializers {
+    "io.gearpump.streaming.examples.streamingkmeans.InputMessage" = ""
+    "io.gearpump.streaming.examples.streamingkmeans.ResultMessage" = ""
+  }
+}
\ No newline at end of file
diff --git a/examples/streaming/streamingkmeans/src/main/scala/io/gearpump/streaming/examples/streamingkmeans/ClusterCollection.scala b/examples/streaming/streamingkmeans/src/main/scala/io/gearpump/streaming/examples/streamingkmeans/ClusterCollection.scala
new file mode 100644
index 000000000..ab7ab7ca8
--- /dev/null
+++ b/examples/streaming/streamingkmeans/src/main/scala/io/gearpump/streaming/examples/streamingkmeans/ClusterCollection.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.gearpump.streaming.examples.streamingkmeans
+
+import io.gearpump.Message
+import io.gearpump.cluster.UserConfig
+import io.gearpump.streaming.task.{StartTime, Task, TaskContext}
+
+/**
+ * Collects the k per-cluster distances reported for each point, picks the closest cluster,
+ * and tells the ClusterDistribution tasks when it is time to update their centers.
+ */
+class ClusterCollection(taskContext: TaskContext, conf: UserConfig)
+  extends Task(taskContext, conf) {
+  import taskContext.output
+
+  private val k = conf.getInt("k").get
+  private val maxNumber = conf.getInt("maxNumber").get
+
+  private[streamingkmeans] var minTaskId = 0
+  private[streamingkmeans] var minDistance = Double.MaxValue
+  private[streamingkmeans] var minDistPoint: List[Double] = null
+
+  private[streamingkmeans] var currentNumber = 0
+  private[streamingkmeans] var totalNumber = 0
+
+  override def onStart(startTime: StartTime): Unit = super.onStart(startTime)
+
+  override def onNext(msg: Message): Unit = {
+    if (null == msg) {
+      return
+    }
+
+    // Each message carries (taskId, distance to that task's center, the point itself).
+    val (taskId, distance, point) = msg.msg.asInstanceOf[(Int, Double, List[Double])]
+    if (distance < minDistance) {
+      minDistance = distance
+      minDistPoint = point
+      minTaskId = taskId
+    }
+
+    currentNumber += 1
+    // One round is complete once all k distribution tasks have reported for this point.
+    if (k == currentNumber) {
+      currentNumber = 0
+      totalNumber += 1
+      if (maxNumber == totalNumber) {
+        totalNumber = 0
+        output(new Message(new ResultMessage(minTaskId, minDistPoint, true)))
+      } else {
+        output(new Message(new ResultMessage(minTaskId, minDistPoint, false)))
+      }
+    }
+  }
+
+  override def onStop(): Unit = super.onStop()
+}
diff --git a/examples/streaming/streamingkmeans/src/main/scala/io/gearpump/streaming/examples/streamingkmeans/ClusterDistribution.scala b/examples/streaming/streamingkmeans/src/main/scala/io/gearpump/streaming/examples/streamingkmeans/ClusterDistribution.scala
new file mode 100644
index 000000000..04935e500
--- /dev/null
+++ b/examples/streaming/streamingkmeans/src/main/scala/io/gearpump/streaming/examples/streamingkmeans/ClusterDistribution.scala
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.gearpump.streaming.examples.streamingkmeans
+
+import java.util.concurrent.LinkedBlockingQueue
+
+import io.gearpump.Message
+import io.gearpump.cluster.UserConfig
+import io.gearpump.streaming.task.{StartTime, Task, TaskContext}
+
+import scala.collection.mutable
+import scala.util.Random
+
+/**
+ * Each ClusterDistribution task owns one cluster center. It reports the distance from every
+ * incoming point to its center and accumulates the points that ClusterCollection assigns to it,
+ * folding them into the center with the mini-batch update rule documented on updateCenter().
+ */
+class ClusterDistribution(taskContext: TaskContext, conf: UserConfig)
+  extends Task(taskContext, conf) {
+  import taskContext.output
+
+  private[streamingkmeans] val dataQueue: LinkedBlockingQueue[List[Double]] =
+    new LinkedBlockingQueue[List[Double]]()
+  private[streamingkmeans] var isBegin: Boolean = true
+
+  private val decayFactor = conf.getDouble("decayFactor").get
+  private val dimension = conf.getInt("dimension").get
+
+  private[streamingkmeans] val center: Array[Double] = new Array[Double](dimension)
+  private[streamingkmeans] val points: mutable.MutableList[List[Double]] = new mutable.MutableList()
+  private[streamingkmeans] var previousNumber = 0
+  private[streamingkmeans] var currentNumber = 0
+
+  /**
+   * Initializes the center randomly.
+   */
+  private[streamingkmeans] def initCenter(): Unit = {
+    val random = new Random()
+    for (i <- center.indices) {
+      center.update(i, random.nextGaussian())
+    }
+  }
+
+  /**
+   * The update algorithm uses the "mini-batch" k-means rule,
+   * generalized to incorporate forgetfulness (i.e. decay).
+   * The update rule (for each cluster) is:
+   *
+   * {{{
+   * c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t]
+   * n_t+1 = n_t * a + m_t
+   * }}}
+   *
+   * Where c_t is the previously estimated centroid for that cluster,
+   * n_t is the number of points assigned to it thus far, x_t is the centroid
+   * estimated on the current batch, and m_t is the number of points assigned
+   * to that centroid in the current batch.
+   *
+   * The decay factor 'a' scales the contribution of the clusters as estimated thus far,
+   * by applying a as a discount weighting on the current point when evaluating
+   * new incoming data. If a=1, all batches are weighted equally. If a=0, new centroids
+   * are determined entirely by recent data. Lower values correspond to
+   * more forgetting.
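+   *
+   * As a purely illustrative calculation of this rule (the numbers are not taken from the
+   * code): with a = 1.0, c_t = 2.0, n_t = 10, x_t = 4.0 and m_t = 5, the update gives
+   * c_t+1 = (2.0 * 10 * 1.0 + 4.0 * 5) / (10 + 5) = 40 / 15 ~= 2.67 and n_t+1 = 10 * 1.0 + 5 = 15.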
+   */
+  private[streamingkmeans] def updateCenter(): Unit = {
+    if (0 == currentNumber) {
+      return
+    }
+
+    val newCenter: Array[Double] = new Array[Double](dimension)
+    for (i <- newCenter.indices) {
+      var sum = 0.0
+      for (point <- points) {
+        sum += point(i)
+      }
+      sum /= currentNumber
+      newCenter.update(i, sum)
+    }
+
+    for (i <- center.indices) {
+      center.update(i,
+        (center(i) * previousNumber * decayFactor + newCenter(i) * currentNumber)
+          / (previousNumber + currentNumber))
+    }
+  }
+
+  private[streamingkmeans] def getDistance(point: List[Double]): Double = {
+    var distance = 0.0
+    for (i <- 0 until dimension) {
+      distance += ((point(i) - center(i)) * (point(i) - center(i)))
+    }
+    Math.sqrt(distance)
+  }
+
+  override def onStart(startTime: StartTime): Unit = {
+    initCenter()
+  }
+
+  override def onNext(msg: Message): Unit = {
+    if (null == msg) {
+      return
+    }
+
+    val message = msg.msg.asInstanceOf[ClusterMessage]
+
+    message match {
+      case InputMessage(point) =>
+        if (isBegin) {
+          isBegin = false
+          output(new Message((taskContext.taskId.index, getDistance(point), point)))
+        } else {
+          dataQueue.put(point)
+        }
+      case ResultMessage(taskId, point, doCluster) =>
+        if (taskContext.taskId.index == taskId) {
+          points += point
+          currentNumber += 1
+        }
+        if (doCluster) {
+          updateCenter()
+          LOG.info(s"task ${taskContext.taskId.index}, center ${center.mkString(",")}")
+          points.clear()
+          previousNumber += currentNumber
+          currentNumber = 0
+        }
+        val newPoint = dataQueue.take()
+        output(new Message((taskContext.taskId.index, getDistance(newPoint), newPoint)))
+    }
+  }
+
+  override def onStop(): Unit = super.onStop()
+}
diff --git a/examples/streaming/streamingkmeans/src/main/scala/io/gearpump/streaming/examples/streamingkmeans/ClusterMessage.scala b/examples/streaming/streamingkmeans/src/main/scala/io/gearpump/streaming/examples/streamingkmeans/ClusterMessage.scala
new file mode 100644
index 000000000..965ee3a60
--- /dev/null
+++ b/examples/streaming/streamingkmeans/src/main/scala/io/gearpump/streaming/examples/streamingkmeans/ClusterMessage.scala
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.gearpump.streaming.examples.streamingkmeans
+
+trait ClusterMessage extends Serializable
+case class InputMessage(point: List[Double]) extends ClusterMessage
+case class ResultMessage(taskId: Int, point: List[Double], doCluster: Boolean) extends ClusterMessage
\ No newline at end of file
diff --git a/examples/streaming/streamingkmeans/src/main/scala/io/gearpump/streaming/examples/streamingkmeans/RandomRBFSource.scala b/examples/streaming/streamingkmeans/src/main/scala/io/gearpump/streaming/examples/streamingkmeans/RandomRBFSource.scala
new file mode 100644
index 000000000..897b03959
--- /dev/null
+++ b/examples/streaming/streamingkmeans/src/main/scala/io/gearpump/streaming/examples/streamingkmeans/RandomRBFSource.scala
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.gearpump.streaming.examples.streamingkmeans
+
+import io.gearpump.util.LogUtil
+import io.gearpump.{TimeStamp, Message}
+import io.gearpump.streaming.source.DataSource
+import io.gearpump.streaming.task.TaskContext
+import org.slf4j.Logger
+
+import scala.util.Random
+
+object RandomRBFSource {
+  private val LOG: Logger = LogUtil.getLogger(classOf[RandomRBFSource])
+}
+
+/**
+ * Generates data points via radial basis functions, following the RandomRBFGenerator
+ * in Huawei's streamDM (https://github.com/huawei-noah/streamDM).
+ */
+class RandomRBFSource(k: Int, dimension: Int) extends DataSource {
+
+  class Centroid(center: Array[Double], classLab: Int, stdev: Double) {
+    val centre = center
+    val classLabel = classLab
+    val stdDev = stdev
+  }
+
+  val centroids = new Array[Centroid](k)
+  val centroidWeights = new Array[Double](centroids.length)
+  val instanceRandom: Random = new Random()
+
+  def generateCentroids(): Unit = {
+    val modelRand: Random = new Random()
+
+    for (i <- centroids.indices) {
+      val randCentre: Array[Double] = Array.fill[Double](dimension)(modelRand.nextDouble())
+      centroids.update(i, new Centroid(randCentre, modelRand.nextInt(k), modelRand.nextDouble()))
+      centroidWeights.update(i, modelRand.nextDouble())
+    }
+  }
+
+  /**
+   * Chooses an index of the weight array at random, with probability proportional to its weight.
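+   * For example (illustrative numbers): with weights Array(1.0, 3.0), index 1 is returned
+   * roughly three times as often as index 0.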
+   * @param weights Weight array
+   * @param random Random value generator
+   * @return an index of the weight array
+   */
+  private def chooseRandomIndexBasedOnWeights(weights: Array[Double], random: Random): Int = {
+    val probSum = weights.sum
+    val ran = random.nextDouble() * probSum
+    var index: Int = 0
+    var sum: Double = 0.0
+    while ((sum <= ran) && (index < weights.length)) {
+      sum += weights(index)
+      index += 1
+    }
+    index - 1
+  }
+
+  def getPoint: List[Double] = {
+    val index = chooseRandomIndexBasedOnWeights(centroidWeights, instanceRandom)
+    val centroid: Centroid = centroids(index)
+
+    val initFeatureVals: Array[Double] = Array.fill[Double](dimension)(
+      instanceRandom.nextDouble() * 2.0 - 1.0)
+    val magnitude = Math.sqrt(initFeatureVals.foldLeft(0.0) { (a, x) => a + x * x })
+
+    val desiredMag = instanceRandom.nextGaussian() * centroid.stdDev
+    val scale = desiredMag / magnitude
+
+    val featureVals = centroid.centre zip initFeatureVals map { case (a, b) => a + b * scale }
+    featureVals.toList
+  }
+
+  /**
+   * Opens the connection to the data source.
+   * Invoked in the onStart() method of [[io.gearpump.streaming.source.DataSourceTask]].
+   * @param context is the task context at runtime
+   * @param startTime is the start time of the system
+   */
+  override def open(context: TaskContext, startTime: Option[TimeStamp]): Unit = {
+    generateCentroids()
+  }
+
+  /**
+   * Closes the connection to the data source.
+   * Invoked in the onStop() method of [[io.gearpump.streaming.source.DataSourceTask]].
+   */
+  override def close(): Unit = {}
+
+  /**
+   * Reads a number of messages from the data source.
+   * Invoked in each onNext() method of [[io.gearpump.streaming.source.DataSourceTask]].
+   * @param batchSize max number of messages to read
+   * @return a list of messages wrapped in [[io.gearpump.Message]]
+   */
+  override def read(batchSize: Int): List[Message] = {
+    List.fill(batchSize)(new Message(new InputMessage(getPoint)))
+  }
+}
diff --git a/examples/streaming/streamingkmeans/src/main/scala/io/gearpump/streaming/examples/streamingkmeans/StreamingKmeansExample.scala b/examples/streaming/streamingkmeans/src/main/scala/io/gearpump/streaming/examples/streamingkmeans/StreamingKmeansExample.scala
new file mode 100644
index 000000000..f24e107eb
--- /dev/null
+++ b/examples/streaming/streamingkmeans/src/main/scala/io/gearpump/streaming/examples/streamingkmeans/StreamingKmeansExample.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.gearpump.streaming.examples.streamingkmeans
+
+import akka.actor.ActorSystem
+import io.gearpump.cluster.UserConfig
+import io.gearpump.cluster.client.ClientContext
+import io.gearpump.cluster.main.{ArgumentsParser, CLIOption, ParseResult}
+import io.gearpump.partitioner.BroadcastPartitioner
+import io.gearpump.streaming.source.{DataSourceConfig, DataSourceProcessor}
+import io.gearpump.streaming.{Processor, StreamApplication}
+import io.gearpump.util.Graph.Node
+import io.gearpump.util.{AkkaApp, Graph}
+
+/**
+ * This application follows streaming k-means on Spark:
+ * https://databricks.com/blog/2015/01/28/introducing-streaming-k-means-in-spark-1-2.html
+ */
+object StreamingKmeansExample extends AkkaApp with ArgumentsParser {
+  override val options: Array[(String, CLIOption[Any])] = Array(
+    "k" -> CLIOption[Int]("", required = false, defaultValue = Some(2)),
+    "maxBatch" -> CLIOption[Int]("", required = false, defaultValue = Some(1000)),
+    "maxNumber" -> CLIOption[Int]("", required = false, defaultValue = Some(100)),
+    "decayFactor" -> CLIOption[Double]("", required = false, defaultValue = Some(1.0)),
+    "dimension" -> CLIOption[Int]("", required = false, defaultValue = Some(2))
+  )
+
+  def application(config: ParseResult, system: ActorSystem): StreamApplication = {
+    implicit val actorSystem = system
+
+    val k = config.getInt("k")
+    val dimension = config.getInt("dimension")
+    val maxNumber = config.getInt("maxNumber")
+    val maxBatch = config.getInt("maxBatch")
+    val decayFactor = config.getString("decayFactor").toDouble
+
+    val userConfig: UserConfig = UserConfig.empty
+      .withInt("k", k).withInt("dimension", dimension)
+      .withInt("maxNumber", maxNumber).withDouble("decayFactor", decayFactor)
+
+    val sourceConf: UserConfig = UserConfig.empty
+      .withInt(DataSourceConfig.SOURCE_READ_BATCH_SIZE, maxBatch)
+
+    val source = new RandomRBFSource(k, dimension)
+    val sourceProcessor = DataSourceProcessor(source, 1, "data source processor", sourceConf)
+    val distribution = Processor[ClusterDistribution](k, "distribution processor", userConfig)
+    val collection = Processor[ClusterCollection](1, "collection processor", userConfig)
+
+    val broadcastPartition = new BroadcastPartitioner
+
+    val app = StreamApplication("streamingkmeans",
+      Graph(sourceProcessor ~ broadcastPartition ~> distribution ~ broadcastPartition ~>
+        collection ~ broadcastPartition ~> distribution),
+      UserConfig.empty)
+    app
+  }
+
+  override def main(akkaConf: Config, args: Array[String]): Unit = {
+    val config = parse(args)
+    val context = ClientContext(akkaConf)
+    val app = application(config, context.system)
+    context.submit(app)
+    context.close()
+  }
+}
diff --git a/examples/streaming/streamingkmeans/src/test/scala/io/gearpump/streaming/examples/streamingkmeans/ClusterCollectionSpec.scala b/examples/streaming/streamingkmeans/src/test/scala/io/gearpump/streaming/examples/streamingkmeans/ClusterCollectionSpec.scala
new file mode 100644
index 000000000..5adda7571
--- /dev/null
+++ b/examples/streaming/streamingkmeans/src/test/scala/io/gearpump/streaming/examples/streamingkmeans/ClusterCollectionSpec.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.gearpump.streaming.examples.streamingkmeans
+
+import akka.actor.ActorSystem
+import akka.testkit.TestProbe
+import io.gearpump.Message
+import io.gearpump.cluster.{UserConfig, TestUtil}
+import io.gearpump.streaming.MockUtil
+import org.mockito.Mockito._
+import org.scalatest.{Matchers, WordSpec}
+
+class ClusterCollectionSpec extends WordSpec with Matchers {
+  "ClusterCollection" should {
+    "receive statistics from ClusterDistribution" in {
+      val taskContext = MockUtil.mockTaskContext
+
+      implicit val system = ActorSystem("test", TestUtil.DEFAULT_CONFIG)
+
+      val mockTaskActor = TestProbe()
+
+      // mock self ActorRef
+      when(taskContext.self).thenReturn(mockTaskActor.ref)
+
+      val conf = UserConfig.empty.withInt("k", 2).withInt("maxNumber", 1)
+      val collection: ClusterCollection = new ClusterCollection(taskContext, conf)
+      assert(collection.currentNumber == 0)
+      assert(collection.totalNumber == 0)
+
+      collection.onNext(new Message((0, 0.2, List[Double](1.0, 2.0))))
+      assert(collection.currentNumber == 1)
+      assert(collection.totalNumber == 0)
+      assert(collection.minDistance == 0.2)
+      assert(collection.minDistPoint == List[Double](1.0, 2.0))
+
+      collection.onNext(new Message((0, 0.1, List[Double](2.0, 3.0))))
+      assert(collection.currentNumber == 0)
+      assert(collection.totalNumber == 0)
+      assert(collection.minDistance == 0.1)
+      assert(collection.minDistPoint == List[Double](2.0, 3.0))
+      verify(taskContext, times(1)).output(
+        new Message(new ResultMessage(0, List[Double](2.0, 3.0), true)))
+
+      system.shutdown()
+      system.awaitTermination()
+    }
+  }
+}
diff --git a/examples/streaming/streamingkmeans/src/test/scala/io/gearpump/streaming/examples/streamingkmeans/ClusterDistributionSpec.scala b/examples/streaming/streamingkmeans/src/test/scala/io/gearpump/streaming/examples/streamingkmeans/ClusterDistributionSpec.scala
new file mode 100644
index 000000000..1d8996fe7
--- /dev/null
+++ b/examples/streaming/streamingkmeans/src/test/scala/io/gearpump/streaming/examples/streamingkmeans/ClusterDistributionSpec.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.gearpump.streaming.examples.streamingkmeans
+
+import akka.actor.ActorSystem
+import akka.testkit.TestProbe
+import io.gearpump.Message
+import io.gearpump.cluster.{TestUtil, UserConfig}
+import io.gearpump.streaming.MockUtil
+import io.gearpump.streaming.task.StartTime
+import org.mockito.Matchers._
+import org.mockito.Mockito._
+import org.scalatest.{Matchers, WordSpec}
+
+class ClusterDistributionSpec extends WordSpec with Matchers {
+  "ClusterDistribution" should {
+    "receive data points from the source" in {
+      val taskContext = MockUtil.mockTaskContext
+
+      implicit val system = ActorSystem("test", TestUtil.DEFAULT_CONFIG)
+
+      val mockTaskActor = TestProbe()
+
+      // mock self ActorRef
+      when(taskContext.self).thenReturn(mockTaskActor.ref)
+
+      val conf = UserConfig.empty.withInt("dimension", 2).withDouble("decayFactor", 1.0)
+      val distribution: ClusterDistribution = new ClusterDistribution(taskContext, conf)
+
+      distribution.onStart(StartTime(0))
+      assert(distribution.isBegin)
+
+      val point: List[Double] = List[Double](1.0, 2.0)
+      val inputMessage: InputMessage = new InputMessage(point)
+      val taskId: Int = taskContext.taskId.index
+      val distance: Double = distribution.getDistance(point)
+
+      distribution.onNext(new Message(inputMessage))
+      assert(!distribution.isBegin)
+      assert(distribution.dataQueue.isEmpty)
+      verify(taskContext, times(1)).output(new Message((taskId, distance, point)))
+
+      distribution.onNext(new Message(inputMessage))
+      assert(distribution.dataQueue.size() == 1)
+
+      system.shutdown()
+      system.awaitTermination()
+    }
+
+    "receive results from ClusterCollection" in {
+      val taskContext = MockUtil.mockTaskContext
+
+      implicit val system = ActorSystem("test", TestUtil.DEFAULT_CONFIG)
+
+      val mockTaskActor = TestProbe()
+
+      // mock self ActorRef
+      when(taskContext.self).thenReturn(mockTaskActor.ref)
+
+      val conf = UserConfig.empty.withInt("dimension", 2).withDouble("decayFactor", 1.0)
+      val distribution: ClusterDistribution = new ClusterDistribution(taskContext, conf)
+
+      distribution.onStart(StartTime(0))
+
+      val taskId: Int = taskContext.taskId.index
+      val point: List[Double] = List[Double](1.0, 2.0)
+      val inputMessage: InputMessage = new InputMessage(point)
+      val center: Array[Double] = distribution.center.clone()
+
+      distribution.onNext(new Message(inputMessage))
+      distribution.onNext(new Message(inputMessage))
+      distribution.onNext(new Message(inputMessage))
+      assert(distribution.dataQueue.size() == 2)
+
+      distribution.onNext(new Message(new ResultMessage(taskId + 1, point, true)))
+      assert(distribution.currentNumber == 0 && distribution.points.isEmpty)
+      assert(distribution.dataQueue.size() == 1)
+      assert(distribution.center.sameElements(center))
+
+      distribution.onNext(new Message(new ResultMessage(taskId, point, true)))
+      assert(distribution.currentNumber == 0 && distribution.points.isEmpty)
+      assert(distribution.dataQueue.isEmpty)
+      assert(!distribution.center.sameElements(center))
+
+      system.shutdown()
+      system.awaitTermination()
+    }
+  }
+}
diff --git a/examples/streaming/streamingkmeans/src/test/scala/io/gearpump/streaming/examples/streamingkmeans/RandomRBFSourceSpec.scala b/examples/streaming/streamingkmeans/src/test/scala/io/gearpump/streaming/examples/streamingkmeans/RandomRBFSourceSpec.scala
new file mode 100644
index 000000000..956ae81c6
--- /dev/null
+++ b/examples/streaming/streamingkmeans/src/test/scala/io/gearpump/streaming/examples/streamingkmeans/RandomRBFSourceSpec.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.gearpump.streaming.examples.streamingkmeans
+
+import org.scalatest.FlatSpec
+
+class RandomRBFSourceSpec extends FlatSpec {
+  behavior of "RandomRBFSource"
+
+  it should "produce data points with dimension `dimension` from getPoint" in {
+    val k: Int = 2
+    val dimension: Int = 3
+    val randomRBFSource: RandomRBFSource = new RandomRBFSource(k, dimension)
+    randomRBFSource.generateCentroids()
+    assert(randomRBFSource.getPoint.length == dimension)
+  }
+
+  it should "generate `k` centroids" in {
+    val k: Int = 2
+    val dimension: Int = 3
+    val randomRBFSource: RandomRBFSource = new RandomRBFSource(k, dimension)
+    randomRBFSource.generateCentroids()
+    assert(randomRBFSource.centroids.length == k)
+  }
+}
diff --git a/examples/streaming/streamingkmeans/src/test/scala/io/gearpump/streaming/examples/streamingkmeans/StreamingKmeansSpec.scala b/examples/streaming/streamingkmeans/src/test/scala/io/gearpump/streaming/examples/streamingkmeans/StreamingKmeansSpec.scala
new file mode 100644
index 000000000..8c94cd304
--- /dev/null
+++ b/examples/streaming/streamingkmeans/src/test/scala/io/gearpump/streaming/examples/streamingkmeans/StreamingKmeansSpec.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.gearpump.streaming.examples.streamingkmeans
+
+import io.gearpump.cluster.ClientToMaster.SubmitApplication
+import io.gearpump.cluster.MasterToClient.SubmitApplicationResult
+import io.gearpump.cluster.{MasterHarness, TestUtil}
+import org.scalatest.prop.PropertyChecks
+import org.scalatest.{BeforeAndAfter, Matchers, PropSpec}
+
+import scala.concurrent.Future
+import scala.util.Success
+
+class StreamingKmeansSpec
+  extends PropSpec with PropertyChecks with Matchers with BeforeAndAfter with MasterHarness {
+  before {
+    startActorSystem()
+  }
+
+  after {
+    shutdownActorSystem()
+  }
+
+  override def config = TestUtil.DEFAULT_CONFIG
+
+  property("StreamingKmeansExample should succeed to submit application with required arguments") {
+    val requiredArgs = Array.empty[String]
+    val optionalArgs = Array(
+      "-k", "2",
+      "-dimension", "2",
+      "-maxBatch", "1000",
+      "-maxNumber", "100",
+      "-decayFactor", "1.0")
+
+    val args = {
+      Table(
+        ("requiredArgs", "optionalArgs"),
+        (requiredArgs, optionalArgs)
+      )
+    }
+    val masterReceiver = createMockMaster()
+    forAll(args) { (requiredArgs: Array[String], optionalArgs: Array[String]) =>
+
+      val args = requiredArgs ++ optionalArgs
+
+      Future {StreamingKmeansExample.main(masterConfig, args)}
+
+      masterReceiver.expectMsgType[SubmitApplication](PROCESS_BOOT_TIME)
+      masterReceiver.reply(SubmitApplicationResult(Success(0)))
+    }
+  }
+}
diff --git a/project/BuildExample.scala b/project/BuildExample.scala
index 3bd93c3e9..39f8951e3 100644
--- a/project/BuildExample.scala
+++ b/project/BuildExample.scala
@@ -27,7 +27,7 @@ object BuildExample extends sbt.Build {
     id = "gearpump-examples",
     base = file("examples"),
     settings = commonSettings ++ noPublish
-  ) aggregate(wordcount, wordcountJava, complexdag, sol, fsio, examples_kafka,
+  ) aggregate(wordcount, wordcountJava, complexdag, sol, fsio, examples_kafka, streamingkmeans,
     distributedshell, stockcrawler, transport, examples_state, pagerank, distributeservice)
 
   lazy val wordcountJava = Project(
@@ -56,6 +56,17 @@ object BuildExample extends sbt.Build {
     )
   ) dependsOn(streaming % "test->test; provided", daemon % "test->test; provided")
 
+  lazy val streamingkmeans = Project(
+    id = "gearpump-examples-streamingkmeans",
+    base = file("examples/streaming/streamingkmeans"),
+    settings = commonSettings ++ noPublish ++ myAssemblySettings ++
+      Seq(
+        mainClass in (Compile, packageBin) := Some("io.gearpump.streaming.examples.streamingkmeans.StreamingKmeansExample"),
+        target in assembly := baseDirectory.value.getParentFile.getParentFile / "target" /
+          CrossVersion.binaryScalaVersion(scalaVersion.value)
+      )
+  ) dependsOn (streaming % "test->test; provided", daemon % "test->test; provided")
+
   lazy val sol = Project(
     id = "gearpump-examples-sol",
     base = file("examples/streaming/sol"),
@@ -229,4 +240,4 @@ object BuildExample extends sbt.Build {
         CrossVersion.binaryScalaVersion(scalaVersion.value)
       )
     ) dependsOn (streaming % "test->test; provided")
-}
\ No newline at end of file
+}