Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,18 @@ class DataQuanta[Out: ClassTag](val operator: ElementaryOperator, outputIndex: I
predictOperator
}

def predictJava[ThatOut: ClassTag, Result: ClassTag](
that: DataQuanta[ThatOut]
): DataQuanta[Result] = {
val predictOperator = new PredictOperator(
implicitly[ClassTag[ThatOut]].runtimeClass,
implicitly[ClassTag[Result]].runtimeClass
)
this.connectTo(predictOperator, 0)
that.connectTo(predictOperator, 1)
predictOperator
}

def dlTraining[ThatOut: ClassTag](
model: DLModel,
option: DLTrainingOperator.Option,
Expand Down Expand Up @@ -616,6 +628,24 @@ class DataQuanta[Out: ClassTag](val operator: ElementaryOperator, outputIndex: I
dlTrainingOperator
}


def dlTrainingJava[ThatOut: ClassTag](
model: DLModel,
option: DLTrainingOperator.Option,
that: DataQuanta[ThatOut]
): DataQuanta[DLModel] = {
val dlTrainingOperator = new DLTrainingOperator(
model,
option,
implicitly[ClassTag[Out]].runtimeClass,
implicitly[ClassTag[ThatOut]].runtimeClass
)

this.connectTo(dlTrainingOperator, 0)
that.connectTo(dlTrainingOperator, 1)
dlTrainingOperator
}

/**
* Feeds this and a further instance into a [[CoGroupOperator]].
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ import java.util.{Collection => JavaCollection}
import org.apache.wayang.api.graph.{Edge, EdgeDataQuantaBuilder, EdgeDataQuantaBuilderDecorator}
import org.apache.wayang.api.util.{DataQuantaBuilderCache, TypeTrap}
import org.apache.wayang.basic.data.{Record, Tuple2 => RT2}
import org.apache.wayang.basic.operators.{GlobalReduceOperator, LocalCallbackSink, MapOperator, SampleOperator}
import org.apache.wayang.basic.model.{DLModel, Model}
import org.apache.wayang.basic.operators.{DLTrainingOperator, GlobalReduceOperator, LocalCallbackSink, MapOperator, SampleOperator}
import org.apache.wayang.commons.util.profiledb.model.Experiment
import org.apache.wayang.core.function.FunctionDescriptor.{SerializableBiFunction, SerializableBinaryOperator, SerializableFunction, SerializableIntUnaryOperator, SerializablePredicate}
import org.apache.wayang.core.optimizer.ProbabilisticDoubleInterval
Expand Down Expand Up @@ -273,6 +274,30 @@ trait DataQuantaBuilder[+This <: DataQuantaBuilder[_, Out], Out] extends Logging
thatKeyUdf: SerializableFunction[ThatOut, Key]) =
new JoinDataQuantaBuilder(this, that, thisKeyUdf, thatKeyUdf)

/**
* Feed the built [[DataQuanta]] of this and the given instance into a
* [[org.apache.wayang.basic.operators.DLTrainingOperator]].
*
* @param that the other [[DataQuantaBuilder]] to join with
* @param model model for the [[org.apache.wayang.basic.operators.DLTrainingOperator]]
* @param option option for the [[org.apache.wayang.basic.operators.DLTrainingOperator]]
* @return a [[DLTrainingDataQuantaBuilder]]
*/
def dlTraining[ThatOut](that: DataQuantaBuilder[_, ThatOut],
model: DLModel,
option: DLTrainingOperator.Option) =
new DLTrainingDataQuantaBuilder(this, that, model, option)

/**
* Feed the built [[DataQuanta]] of this and the given instance into a
* [[org.apache.wayang.basic.operators.PredictOperator]].
*
* @param that the other [[DataQuantaBuilder]] to join with
* @return a [[PredictDataQuantaBuilder]]
*/
def predict[ThatOut, Result](that: DataQuantaBuilder[_, ThatOut], resultType: Class[Result]) =
new PredictDataQuantaBuilder(this.asInstanceOf[DataQuantaBuilder[_, Model]], that, resultType)

/**
* Feed the built [[DataQuanta]] of this and the given instance into a
* [[org.apache.wayang.basic.operators.CoGroupOperator]].
Expand Down Expand Up @@ -1336,6 +1361,53 @@ class JoinDataQuantaBuilder[In0, In1, Key](inputDataQuanta0: DataQuantaBuilder[_

}

/**
* [[DataQuantaBuilder]] implementation for [[org.apache.wayang.basic.operators.DLTrainingOperator]]s.
*
* @param inputDataQuanta0 [[DataQuantaBuilder]] for the first input [[DataQuanta]]
* @param inputDataQuanta1 [[DataQuantaBuilder]] for the first input [[DataQuanta]]
* @param model model for the [[org.apache.wayang.basic.operators.DLTrainingOperator]]
* @param option option for the [[org.apache.wayang.basic.operators.DLTrainingOperator]]
*/
class DLTrainingDataQuantaBuilder[In0, In1](inputDataQuanta0: DataQuantaBuilder[_, In0],
inputDataQuanta1: DataQuantaBuilder[_, In1],
model: DLModel,
option: DLTrainingOperator.Option)
(implicit javaPlanBuilder: JavaPlanBuilder)
extends BasicDataQuantaBuilder[DLTrainingDataQuantaBuilder[In0, In1], DLModel] {

// Since we are currently not looking at type parameters, we can statically determine the output type.
locally {
this.outputTypeTrap.dataSetType = dataSetType[DLModel]
}

override protected def build =
inputDataQuanta0.dataQuanta()
.dlTrainingJava(model, option, inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag)
}

/**
* [[DataQuantaBuilder]] implementation for [[org.apache.wayang.basic.operators.PredictOperator]]s.
*
* @param inputDataQuanta0 [[DataQuantaBuilder]] for the first input [[DataQuanta]]
* @param inputDataQuanta1 [[DataQuantaBuilder]] for the first input [[DataQuanta]]
*/
class PredictDataQuantaBuilder[In1, Out](inputDataQuanta0: DataQuantaBuilder[_, Model],
inputDataQuanta1: DataQuantaBuilder[_, In1],
outClass: Class[Out])
(implicit javaPlanBuilder: JavaPlanBuilder)
extends BasicDataQuantaBuilder[PredictDataQuantaBuilder[In1, Out], Out] {

// Since we are currently not looking at type parameters, we can statically determine the output type.
locally {
this.outputTypeTrap.dataSetType = dataSetType[Out]
}

override protected def build =
inputDataQuanta0.dataQuanta().
predictJava(inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag, ClassTag.apply(outClass))
}

/**
* [[DataQuantaBuilder]] implementation for [[org.apache.wayang.basic.operators.CoGroupOperator]]s.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public class TensorflowIrisIT {
"Iris-virginica", 2
);

@Ignore
@Test
public void test() {
final Tuple<Operator, Operator> trainSource = fileOperation(TRAIN_PATH, true);
final Tuple<Operator, Operator> testSource = fileOperation(TEST_PATH, false);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.wayang.tests;

import org.apache.wayang.api.*;
import org.apache.wayang.basic.model.DLModel;
import org.apache.wayang.basic.model.op.*;
import org.apache.wayang.basic.model.op.nn.CrossEntropyLoss;
import org.apache.wayang.basic.model.op.nn.Linear;
import org.apache.wayang.basic.model.op.nn.Sigmoid;
import org.apache.wayang.basic.model.optimizer.Adam;
import org.apache.wayang.basic.model.optimizer.Optimizer;
import org.apache.wayang.basic.operators.DLTrainingOperator;
import org.apache.wayang.core.api.WayangContext;
import org.apache.wayang.core.util.Tuple;
import org.apache.wayang.java.Java;
import org.apache.wayang.tensorflow.Tensorflow;
import org.junit.Test;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
* Test the Tensorflow integration with Wayang.
* Note: this test fails on M1 Macs because of Tensorflow-Java incompatibility.
*/
public class TensorflowIrisScalaLikeApiIT {

public static URI TRAIN_PATH = createUri("/iris_train.csv");
public static URI TEST_PATH = createUri("/iris_test.csv");

public static Map<String, Integer> LABEL_MAP = Map.of(
"Iris-setosa", 0,
"Iris-versicolor", 1,
"Iris-virginica", 2
);

@Test
public void test() {
WayangContext wayangContext = new WayangContext()
.with(Java.basicPlugin())
.with(Tensorflow.plugin());

JavaPlanBuilder plan = new JavaPlanBuilder(wayangContext);

final Tuple<DataQuantaBuilder<?, float[]>, DataQuantaBuilder<?, Integer>> trainSource =
fileOperation(plan, TRAIN_PATH, true);
final Tuple<DataQuantaBuilder<?, float[]>, DataQuantaBuilder<?, Integer>> testSource =
fileOperation(plan, TEST_PATH, false);

/* training features */
DataQuantaBuilder<?, float[]> trainXSource = trainSource.field0;

/* training labels */
DataQuantaBuilder<?, Integer> trainYSource = trainSource.field1;

/* test features */
DataQuantaBuilder<?, float[]> testXSource = testSource.field0;

/* test labels */
DataQuantaBuilder<?, Integer> testYSource = testSource.field1;

/* model */
Op l1 = new Linear(4, 32, true);
Op s1 = new Sigmoid();
Op l2 = new Linear(32, 3, true);
s1.with(l1.with(new Input(Input.Type.FEATURES)));
l2.with(s1);

DLModel model = new DLModel(l2);

/* training options */
// 1. loss function
Op criterion = new CrossEntropyLoss(3);
criterion.with(
new Input(Input.Type.PREDICTED, Op.DType.FLOAT32),
new Input(Input.Type.LABEL, Op.DType.INT32)
);

// 2. accuracy calculation function
Op acc = new Mean(0);
acc.with(new Cast(Op.DType.FLOAT32).with(new Eq().with(
new ArgMax(1).with(new Input(Input.Type.PREDICTED, Op.DType.FLOAT32)),
new Input(Input.Type.LABEL, Op.DType.INT32)
)));

// 3. optimizer with learning rate
Optimizer optimizer = new Adam(0.1f);

// 4. batch size
int batchSize = 45;

// 5. epoch
int epoch = 10;

DLTrainingOperator.Option option = new DLTrainingOperator.Option(criterion, optimizer, batchSize, epoch);
option.setAccuracyCalculation(acc);

/* training operator */
DLTrainingDataQuantaBuilder<float[], Integer> trainingOperator =
trainXSource.dlTraining(trainYSource, model, option);

/* predict operator */
PredictDataQuantaBuilder<float[], float[]> predictOperator =
trainingOperator.predict(testXSource, float[].class);

/* map to label */
MapDataQuantaBuilder<float[], Integer> mapOperator = predictOperator.map(array -> {
int maxIdx = 0;
float maxVal = array[0];
for (int i = 1; i < array.length; i++) {
if (array[i] > maxVal) {
maxIdx = i;
maxVal = array[i];
}
}
return maxIdx;
});

/* sink */
List<Integer> predicted = new ArrayList<>(mapOperator.collect());
// fixme: Currently, wayang's scala-like api only supports a single collect,
// so it is not possible to collect multiple result lists in a single plan.
// List<Integer> groundTruth = new ArrayList<>(testYSource.collect());

System.out.println("predicted: " + predicted);
// System.out.println("ground truth: " + groundTruth);

// float success = 0;
// for (int i = 0; i < predicted.size(); i++) {
// if (predicted.get(i).equals(groundTruth.get(i))) {
// success += 1;
// }
// }
// System.out.println("test accuracy: " + success / predicted.size());
}

public static Tuple<DataQuantaBuilder<?, float[]>, DataQuantaBuilder<?, Integer>>
fileOperation(JavaPlanBuilder plan, URI uri, boolean random) {
DataQuantaBuilder<?, String> textFileSource = plan.readTextFile(uri.toString());

if (random) {
Random r = new Random();
textFileSource = textFileSource.sort(e -> r.nextInt());
}

MapDataQuantaBuilder<String, Tuple<float[], Integer>> mapXY = textFileSource.map(line -> {
String[] parts = line.split(",");
float[] x = new float[parts.length - 1];
for (int i = 0; i < x.length; i++) {
x[i] = Float.parseFloat(parts[i]);
}
int y = LABEL_MAP.get(parts[parts.length - 1]);
return new Tuple<>(x, y);
});

MapDataQuantaBuilder<Tuple<float[], Integer>, float[]> mapX = mapXY.map(tuple -> tuple.field0);
MapDataQuantaBuilder<Tuple<float[], Integer>, Integer> mapY = mapXY.map(tuple -> tuple.field1);

return new Tuple<>(mapX, mapY);
}

public static URI createUri(String resourcePath) {
try {
return TensorflowIrisScalaLikeApiIT.class.getResource(resourcePath).toURI();
} catch (URISyntaxException e) {
throw new IllegalArgumentException("Illegal URI.", e);
}
}
}
Loading