diff --git a/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuanta.scala b/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuanta.scala index 0dd247a3b..ab41cba6d 100644 --- a/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuanta.scala +++ b/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuanta.scala @@ -588,6 +588,18 @@ class DataQuanta[Out: ClassTag](val operator: ElementaryOperator, outputIndex: I predictOperator } + def predictJava[ThatOut: ClassTag, Result: ClassTag]( + that: DataQuanta[ThatOut] + ): DataQuanta[Result] = { + val predictOperator = new PredictOperator( + implicitly[ClassTag[ThatOut]].runtimeClass, + implicitly[ClassTag[Result]].runtimeClass + ) + this.connectTo(predictOperator, 0) + that.connectTo(predictOperator, 1) + predictOperator + } + def dlTraining[ThatOut: ClassTag]( model: DLModel, option: DLTrainingOperator.Option, @@ -616,6 +628,24 @@ class DataQuanta[Out: ClassTag](val operator: ElementaryOperator, outputIndex: I dlTrainingOperator } + + def dlTrainingJava[ThatOut: ClassTag]( + model: DLModel, + option: DLTrainingOperator.Option, + that: DataQuanta[ThatOut] + ): DataQuanta[DLModel] = { + val dlTrainingOperator = new DLTrainingOperator( + model, + option, + implicitly[ClassTag[Out]].runtimeClass, + implicitly[ClassTag[ThatOut]].runtimeClass + ) + + this.connectTo(dlTrainingOperator, 0) + that.connectTo(dlTrainingOperator, 1) + dlTrainingOperator + } + /** * Feeds this and a further instance into a [[CoGroupOperator]]. * diff --git a/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuantaBuilder.scala b/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuantaBuilder.scala index 044ec7ba0..3a20c36ce 100644 --- a/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuantaBuilder.scala +++ b/wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuantaBuilder.scala @@ -27,7 +27,8 @@ import java.util.{Collection => JavaCollection} import org.apache.wayang.api.graph.{Edge, EdgeDataQuantaBuilder, EdgeDataQuantaBuilderDecorator} import org.apache.wayang.api.util.{DataQuantaBuilderCache, TypeTrap} import org.apache.wayang.basic.data.{Record, Tuple2 => RT2} -import org.apache.wayang.basic.operators.{GlobalReduceOperator, LocalCallbackSink, MapOperator, SampleOperator} +import org.apache.wayang.basic.model.{DLModel, Model} +import org.apache.wayang.basic.operators.{DLTrainingOperator, GlobalReduceOperator, LocalCallbackSink, MapOperator, SampleOperator} import org.apache.wayang.commons.util.profiledb.model.Experiment import org.apache.wayang.core.function.FunctionDescriptor.{SerializableBiFunction, SerializableBinaryOperator, SerializableFunction, SerializableIntUnaryOperator, SerializablePredicate} import org.apache.wayang.core.optimizer.ProbabilisticDoubleInterval @@ -273,6 +274,30 @@ trait DataQuantaBuilder[+This <: DataQuantaBuilder[_, Out], Out] extends Logging thatKeyUdf: SerializableFunction[ThatOut, Key]) = new JoinDataQuantaBuilder(this, that, thisKeyUdf, thatKeyUdf) + /** + * Feed the built [[DataQuanta]] of this and the given instance into a + * [[org.apache.wayang.basic.operators.DLTrainingOperator]]. + * + * @param that the other [[DataQuantaBuilder]] to join with + * @param model model for the [[org.apache.wayang.basic.operators.DLTrainingOperator]] + * @param option option for the [[org.apache.wayang.basic.operators.DLTrainingOperator]] + * @return a [[DLTrainingDataQuantaBuilder]] + */ + def dlTraining[ThatOut](that: DataQuantaBuilder[_, ThatOut], + model: DLModel, + option: DLTrainingOperator.Option) = + new DLTrainingDataQuantaBuilder(this, that, model, option) + + /** + * Feed the built [[DataQuanta]] of this and the given instance into a + * [[org.apache.wayang.basic.operators.PredictOperator]]. + * + * @param that the other [[DataQuantaBuilder]] to join with + * @return a [[PredictDataQuantaBuilder]] + */ + def predict[ThatOut, Result](that: DataQuantaBuilder[_, ThatOut], resultType: Class[Result]) = + new PredictDataQuantaBuilder(this.asInstanceOf[DataQuantaBuilder[_, Model]], that, resultType) + /** * Feed the built [[DataQuanta]] of this and the given instance into a * [[org.apache.wayang.basic.operators.CoGroupOperator]]. @@ -1336,6 +1361,53 @@ class JoinDataQuantaBuilder[In0, In1, Key](inputDataQuanta0: DataQuantaBuilder[_ } +/** + * [[DataQuantaBuilder]] implementation for [[org.apache.wayang.basic.operators.DLTrainingOperator]]s. + * + * @param inputDataQuanta0 [[DataQuantaBuilder]] for the first input [[DataQuanta]] + * @param inputDataQuanta1 [[DataQuantaBuilder]] for the first input [[DataQuanta]] + * @param model model for the [[org.apache.wayang.basic.operators.DLTrainingOperator]] + * @param option option for the [[org.apache.wayang.basic.operators.DLTrainingOperator]] + */ +class DLTrainingDataQuantaBuilder[In0, In1](inputDataQuanta0: DataQuantaBuilder[_, In0], + inputDataQuanta1: DataQuantaBuilder[_, In1], + model: DLModel, + option: DLTrainingOperator.Option) + (implicit javaPlanBuilder: JavaPlanBuilder) + extends BasicDataQuantaBuilder[DLTrainingDataQuantaBuilder[In0, In1], DLModel] { + + // Since we are currently not looking at type parameters, we can statically determine the output type. + locally { + this.outputTypeTrap.dataSetType = dataSetType[DLModel] + } + + override protected def build = + inputDataQuanta0.dataQuanta() + .dlTrainingJava(model, option, inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag) +} + +/** + * [[DataQuantaBuilder]] implementation for [[org.apache.wayang.basic.operators.PredictOperator]]s. + * + * @param inputDataQuanta0 [[DataQuantaBuilder]] for the first input [[DataQuanta]] + * @param inputDataQuanta1 [[DataQuantaBuilder]] for the first input [[DataQuanta]] + */ +class PredictDataQuantaBuilder[In1, Out](inputDataQuanta0: DataQuantaBuilder[_, Model], + inputDataQuanta1: DataQuantaBuilder[_, In1], + outClass: Class[Out]) + (implicit javaPlanBuilder: JavaPlanBuilder) + extends BasicDataQuantaBuilder[PredictDataQuantaBuilder[In1, Out], Out] { + + // Since we are currently not looking at type parameters, we can statically determine the output type. + locally { + this.outputTypeTrap.dataSetType = dataSetType[Out] + } + + override protected def build = + inputDataQuanta0.dataQuanta(). + predictJava(inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag, ClassTag.apply(outClass)) +} + /** * [[DataQuantaBuilder]] implementation for [[org.apache.wayang.basic.operators.CoGroupOperator]]s. * diff --git a/wayang-tests-integration/src/test/java/org/apache/wayang/tests/TensorflowIrisIT.java b/wayang-tests-integration/src/test/java/org/apache/wayang/tests/TensorflowIrisIT.java index f8c08b75c..247d7ead6 100644 --- a/wayang-tests-integration/src/test/java/org/apache/wayang/tests/TensorflowIrisIT.java +++ b/wayang-tests-integration/src/test/java/org/apache/wayang/tests/TensorflowIrisIT.java @@ -57,7 +57,7 @@ public class TensorflowIrisIT { "Iris-virginica", 2 ); - @Ignore + @Test public void test() { final Tuple trainSource = fileOperation(TRAIN_PATH, true); final Tuple testSource = fileOperation(TEST_PATH, false); diff --git a/wayang-tests-integration/src/test/java/org/apache/wayang/tests/TensorflowIrisScalaLikeApiIT.java b/wayang-tests-integration/src/test/java/org/apache/wayang/tests/TensorflowIrisScalaLikeApiIT.java new file mode 100644 index 000000000..d7da2ca7d --- /dev/null +++ b/wayang-tests-integration/src/test/java/org/apache/wayang/tests/TensorflowIrisScalaLikeApiIT.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.tests; + +import org.apache.wayang.api.*; +import org.apache.wayang.basic.model.DLModel; +import org.apache.wayang.basic.model.op.*; +import org.apache.wayang.basic.model.op.nn.CrossEntropyLoss; +import org.apache.wayang.basic.model.op.nn.Linear; +import org.apache.wayang.basic.model.op.nn.Sigmoid; +import org.apache.wayang.basic.model.optimizer.Adam; +import org.apache.wayang.basic.model.optimizer.Optimizer; +import org.apache.wayang.basic.operators.DLTrainingOperator; +import org.apache.wayang.core.api.WayangContext; +import org.apache.wayang.core.util.Tuple; +import org.apache.wayang.java.Java; +import org.apache.wayang.tensorflow.Tensorflow; +import org.junit.Test; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Random; + +/** + * Test the Tensorflow integration with Wayang. + * Note: this test fails on M1 Macs because of Tensorflow-Java incompatibility. + */ +public class TensorflowIrisScalaLikeApiIT { + + public static URI TRAIN_PATH = createUri("/iris_train.csv"); + public static URI TEST_PATH = createUri("/iris_test.csv"); + + public static Map LABEL_MAP = Map.of( + "Iris-setosa", 0, + "Iris-versicolor", 1, + "Iris-virginica", 2 + ); + + @Test + public void test() { + WayangContext wayangContext = new WayangContext() + .with(Java.basicPlugin()) + .with(Tensorflow.plugin()); + + JavaPlanBuilder plan = new JavaPlanBuilder(wayangContext); + + final Tuple, DataQuantaBuilder> trainSource = + fileOperation(plan, TRAIN_PATH, true); + final Tuple, DataQuantaBuilder> testSource = + fileOperation(plan, TEST_PATH, false); + + /* training features */ + DataQuantaBuilder trainXSource = trainSource.field0; + + /* training labels */ + DataQuantaBuilder trainYSource = trainSource.field1; + + /* test features */ + DataQuantaBuilder testXSource = testSource.field0; + + /* test labels */ + DataQuantaBuilder testYSource = testSource.field1; + + /* model */ + Op l1 = new Linear(4, 32, true); + Op s1 = new Sigmoid(); + Op l2 = new Linear(32, 3, true); + s1.with(l1.with(new Input(Input.Type.FEATURES))); + l2.with(s1); + + DLModel model = new DLModel(l2); + + /* training options */ + // 1. loss function + Op criterion = new CrossEntropyLoss(3); + criterion.with( + new Input(Input.Type.PREDICTED, Op.DType.FLOAT32), + new Input(Input.Type.LABEL, Op.DType.INT32) + ); + + // 2. accuracy calculation function + Op acc = new Mean(0); + acc.with(new Cast(Op.DType.FLOAT32).with(new Eq().with( + new ArgMax(1).with(new Input(Input.Type.PREDICTED, Op.DType.FLOAT32)), + new Input(Input.Type.LABEL, Op.DType.INT32) + ))); + + // 3. optimizer with learning rate + Optimizer optimizer = new Adam(0.1f); + + // 4. batch size + int batchSize = 45; + + // 5. epoch + int epoch = 10; + + DLTrainingOperator.Option option = new DLTrainingOperator.Option(criterion, optimizer, batchSize, epoch); + option.setAccuracyCalculation(acc); + + /* training operator */ + DLTrainingDataQuantaBuilder trainingOperator = + trainXSource.dlTraining(trainYSource, model, option); + + /* predict operator */ + PredictDataQuantaBuilder predictOperator = + trainingOperator.predict(testXSource, float[].class); + + /* map to label */ + MapDataQuantaBuilder mapOperator = predictOperator.map(array -> { + int maxIdx = 0; + float maxVal = array[0]; + for (int i = 1; i < array.length; i++) { + if (array[i] > maxVal) { + maxIdx = i; + maxVal = array[i]; + } + } + return maxIdx; + }); + + /* sink */ + List predicted = new ArrayList<>(mapOperator.collect()); + // fixme: Currently, wayang's scala-like api only supports a single collect, + // so it is not possible to collect multiple result lists in a single plan. +// List groundTruth = new ArrayList<>(testYSource.collect()); + + System.out.println("predicted: " + predicted); +// System.out.println("ground truth: " + groundTruth); + +// float success = 0; +// for (int i = 0; i < predicted.size(); i++) { +// if (predicted.get(i).equals(groundTruth.get(i))) { +// success += 1; +// } +// } +// System.out.println("test accuracy: " + success / predicted.size()); + } + + public static Tuple, DataQuantaBuilder> + fileOperation(JavaPlanBuilder plan, URI uri, boolean random) { + DataQuantaBuilder textFileSource = plan.readTextFile(uri.toString()); + + if (random) { + Random r = new Random(); + textFileSource = textFileSource.sort(e -> r.nextInt()); + } + + MapDataQuantaBuilder> mapXY = textFileSource.map(line -> { + String[] parts = line.split(","); + float[] x = new float[parts.length - 1]; + for (int i = 0; i < x.length; i++) { + x[i] = Float.parseFloat(parts[i]); + } + int y = LABEL_MAP.get(parts[parts.length - 1]); + return new Tuple<>(x, y); + }); + + MapDataQuantaBuilder, float[]> mapX = mapXY.map(tuple -> tuple.field0); + MapDataQuantaBuilder, Integer> mapY = mapXY.map(tuple -> tuple.field1); + + return new Tuple<>(mapX, mapY); + } + + public static URI createUri(String resourcePath) { + try { + return TensorflowIrisScalaLikeApiIT.class.getResource(resourcePath).toURI(); + } catch (URISyntaxException e) { + throw new IllegalArgumentException("Illegal URI.", e); + } + } +}