Skip to content

Commit 141bcac

Browse files
committed
feat: scala-like api for dl
1 parent 6d9eb47 commit 141bcac

File tree

4 files changed

+157
-72
lines changed

4 files changed

+157
-72
lines changed

wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuanta.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,18 @@ class DataQuanta[Out: ClassTag](val operator: ElementaryOperator, outputIndex: I
588588
predictOperator
589589
}
590590

591+
def predictJava[ThatOut: ClassTag, Result: ClassTag](
592+
that: DataQuanta[ThatOut]
593+
): DataQuanta[Result] = {
594+
val predictOperator = new PredictOperator(
595+
implicitly[ClassTag[ThatOut]].runtimeClass,
596+
implicitly[ClassTag[Result]].runtimeClass
597+
)
598+
this.connectTo(predictOperator, 0)
599+
that.connectTo(predictOperator, 1)
600+
predictOperator
601+
}
602+
591603
def dlTraining[ThatOut: ClassTag](
592604
model: DLModel,
593605
option: DLTrainingOperator.Option,
@@ -616,6 +628,24 @@ class DataQuanta[Out: ClassTag](val operator: ElementaryOperator, outputIndex: I
616628
dlTrainingOperator
617629
}
618630

631+
632+
def dlTrainingJava[ThatOut: ClassTag](
633+
model: DLModel,
634+
option: DLTrainingOperator.Option,
635+
that: DataQuanta[ThatOut]
636+
): DataQuanta[DLModel] = {
637+
val dlTrainingOperator = new DLTrainingOperator(
638+
model,
639+
option,
640+
implicitly[ClassTag[Out]].runtimeClass,
641+
implicitly[ClassTag[ThatOut]].runtimeClass
642+
)
643+
644+
this.connectTo(dlTrainingOperator, 0)
645+
that.connectTo(dlTrainingOperator, 1)
646+
dlTrainingOperator
647+
}
648+
619649
/**
620650
* Feeds this and a further instance into a [[CoGroupOperator]].
621651
*

wayang-api/wayang-api-scala-java/src/main/scala/org/apache/wayang/api/DataQuantaBuilder.scala

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ import java.util.{Collection => JavaCollection}
2727
import org.apache.wayang.api.graph.{Edge, EdgeDataQuantaBuilder, EdgeDataQuantaBuilderDecorator}
2828
import org.apache.wayang.api.util.{DataQuantaBuilderCache, TypeTrap}
2929
import org.apache.wayang.basic.data.{Record, Tuple2 => RT2}
30-
import org.apache.wayang.basic.operators.{GlobalReduceOperator, LocalCallbackSink, MapOperator, SampleOperator}
30+
import org.apache.wayang.basic.model.{DLModel, Model}
31+
import org.apache.wayang.basic.operators.{DLTrainingOperator, GlobalReduceOperator, LocalCallbackSink, MapOperator, SampleOperator}
3132
import org.apache.wayang.commons.util.profiledb.model.Experiment
3233
import org.apache.wayang.core.function.FunctionDescriptor.{SerializableBiFunction, SerializableBinaryOperator, SerializableFunction, SerializableIntUnaryOperator, SerializablePredicate}
3334
import org.apache.wayang.core.optimizer.ProbabilisticDoubleInterval
@@ -273,6 +274,30 @@ trait DataQuantaBuilder[+This <: DataQuantaBuilder[_, Out], Out] extends Logging
273274
thatKeyUdf: SerializableFunction[ThatOut, Key]) =
274275
new JoinDataQuantaBuilder(this, that, thisKeyUdf, thatKeyUdf)
275276

277+
/**
278+
* Feed the built [[DataQuanta]] of this and the given instance into a
279+
* [[org.apache.wayang.basic.operators.DLTrainingOperator]].
280+
*
281+
* @param that the other [[DataQuantaBuilder]] to join with
282+
* @param model model for the [[org.apache.wayang.basic.operators.DLTrainingOperator]]
283+
* @param option option for the [[org.apache.wayang.basic.operators.DLTrainingOperator]]
284+
* @return a [[DLTrainingDataQuantaBuilder]]
285+
*/
286+
def dlTraining[ThatOut](that: DataQuantaBuilder[_, ThatOut],
287+
model: DLModel,
288+
option: DLTrainingOperator.Option) =
289+
new DLTrainingDataQuantaBuilder(this, that, model, option)
290+
291+
/**
292+
* Feed the built [[DataQuanta]] of this and the given instance into a
293+
* [[org.apache.wayang.basic.operators.PredictOperator]].
294+
*
295+
* @param that the other [[DataQuantaBuilder]] to join with
296+
* @return a [[PredictDataQuantaBuilder]]
297+
*/
298+
def predict[ThatOut, Result](that: DataQuantaBuilder[_, ThatOut], resultType: Class[Result]) =
299+
new PredictDataQuantaBuilder(this.asInstanceOf[DataQuantaBuilder[_, Model]], that, resultType)
300+
276301
/**
277302
* Feed the built [[DataQuanta]] of this and the given instance into a
278303
* [[org.apache.wayang.basic.operators.CoGroupOperator]].
@@ -1336,6 +1361,53 @@ class JoinDataQuantaBuilder[In0, In1, Key](inputDataQuanta0: DataQuantaBuilder[_
13361361

13371362
}
13381363

1364+
/**
1365+
* [[DataQuantaBuilder]] implementation for [[org.apache.wayang.basic.operators.DLTrainingOperator]]s.
1366+
*
1367+
* @param inputDataQuanta0 [[DataQuantaBuilder]] for the first input [[DataQuanta]]
1368+
* @param inputDataQuanta1 [[DataQuantaBuilder]] for the first input [[DataQuanta]]
1369+
* @param model model for the [[org.apache.wayang.basic.operators.DLTrainingOperator]]
1370+
* @param option option for the [[org.apache.wayang.basic.operators.DLTrainingOperator]]
1371+
*/
1372+
class DLTrainingDataQuantaBuilder[In0, In1](inputDataQuanta0: DataQuantaBuilder[_, In0],
1373+
inputDataQuanta1: DataQuantaBuilder[_, In1],
1374+
model: DLModel,
1375+
option: DLTrainingOperator.Option)
1376+
(implicit javaPlanBuilder: JavaPlanBuilder)
1377+
extends BasicDataQuantaBuilder[DLTrainingDataQuantaBuilder[In0, In1], DLModel] {
1378+
1379+
// Since we are currently not looking at type parameters, we can statically determine the output type.
1380+
locally {
1381+
this.outputTypeTrap.dataSetType = dataSetType[DLModel]
1382+
}
1383+
1384+
override protected def build =
1385+
inputDataQuanta0.dataQuanta()
1386+
.dlTrainingJava(model, option, inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag)
1387+
}
1388+
1389+
/**
1390+
* [[DataQuantaBuilder]] implementation for [[org.apache.wayang.basic.operators.PredictOperator]]s.
1391+
*
1392+
* @param inputDataQuanta0 [[DataQuantaBuilder]] for the first input [[DataQuanta]]
1393+
* @param inputDataQuanta1 [[DataQuantaBuilder]] for the first input [[DataQuanta]]
1394+
*/
1395+
class PredictDataQuantaBuilder[In1, Out](inputDataQuanta0: DataQuantaBuilder[_, Model],
1396+
inputDataQuanta1: DataQuantaBuilder[_, In1],
1397+
outClass: Class[Out])
1398+
(implicit javaPlanBuilder: JavaPlanBuilder)
1399+
extends BasicDataQuantaBuilder[PredictDataQuantaBuilder[In1, Out], Out] {
1400+
1401+
// Since we are currently not looking at type parameters, we can statically determine the output type.
1402+
locally {
1403+
this.outputTypeTrap.dataSetType = dataSetType[Out]
1404+
}
1405+
1406+
override protected def build =
1407+
inputDataQuanta0.dataQuanta().
1408+
predictJava(inputDataQuanta1.dataQuanta())(inputDataQuanta1.classTag, ClassTag.apply(outClass))
1409+
}
1410+
13391411
/**
13401412
* [[DataQuantaBuilder]] implementation for [[org.apache.wayang.basic.operators.CoGroupOperator]]s.
13411413
*

wayang-tests-integration/src/test/java/org/apache/wayang/tests/TensorflowIrisIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ public class TensorflowIrisIT {
5757
"Iris-virginica", 2
5858
);
5959

60-
@Ignore
60+
@Test
6161
public void test() {
6262
final Tuple<Operator, Operator> trainSource = fileOperation(TRAIN_PATH, true);
6363
final Tuple<Operator, Operator> testSource = fileOperation(TEST_PATH, false);

wayang-tests-integration/src/test/java/org/apache/wayang/tests/TensorflowIrisScalaLikeApiIT.java

Lines changed: 53 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,20 @@
1818

1919
package org.apache.wayang.tests;
2020

21+
import org.apache.wayang.api.*;
2122
import org.apache.wayang.basic.model.DLModel;
2223
import org.apache.wayang.basic.model.op.*;
2324
import org.apache.wayang.basic.model.op.nn.CrossEntropyLoss;
2425
import org.apache.wayang.basic.model.op.nn.Linear;
2526
import org.apache.wayang.basic.model.op.nn.Sigmoid;
2627
import org.apache.wayang.basic.model.optimizer.Adam;
2728
import org.apache.wayang.basic.model.optimizer.Optimizer;
28-
import org.apache.wayang.basic.operators.*;
29+
import org.apache.wayang.basic.operators.DLTrainingOperator;
2930
import org.apache.wayang.core.api.WayangContext;
30-
import org.apache.wayang.core.plan.wayangplan.Operator;
31-
import org.apache.wayang.core.plan.wayangplan.WayangPlan;
3231
import org.apache.wayang.core.util.Tuple;
3332
import org.apache.wayang.java.Java;
3433
import org.apache.wayang.tensorflow.Tensorflow;
35-
import org.junit.Ignore;
34+
import org.junit.Test;
3635

3736
import java.net.URI;
3837
import java.net.URISyntaxException;
@@ -56,22 +55,30 @@ public class TensorflowIrisScalaLikeApiIT {
5655
"Iris-virginica", 2
5756
);
5857

59-
@Ignore
58+
@Test
6059
public void test() {
61-
final Tuple<Operator, Operator> trainSource = fileOperation(TRAIN_PATH, true);
62-
final Tuple<Operator, Operator> testSource = fileOperation(TEST_PATH, false);
60+
WayangContext wayangContext = new WayangContext()
61+
.with(Java.basicPlugin())
62+
.with(Tensorflow.plugin());
63+
64+
JavaPlanBuilder plan = new JavaPlanBuilder(wayangContext);
65+
66+
final Tuple<DataQuantaBuilder<?, float[]>, DataQuantaBuilder<?, Integer>> trainSource =
67+
fileOperation(plan, TRAIN_PATH, true);
68+
final Tuple<DataQuantaBuilder<?, float[]>, DataQuantaBuilder<?, Integer>> testSource =
69+
fileOperation(plan, TEST_PATH, false);
6370

6471
/* training features */
65-
Operator trainXSource = trainSource.field0;
72+
DataQuantaBuilder<?, float[]> trainXSource = trainSource.field0;
6673

6774
/* training labels */
68-
Operator trainYSource = trainSource.field1;
75+
DataQuantaBuilder<?, Integer> trainYSource = trainSource.field1;
6976

7077
/* test features */
71-
Operator testXSource = testSource.field0;
78+
DataQuantaBuilder<?, float[]> testXSource = testSource.field0;
7279

7380
/* test labels */
74-
Operator testYSource = testSource.field1;
81+
DataQuantaBuilder<?, Integer> testYSource = testSource.field1;
7582

7683
/* model */
7784
Op l1 = new Linear(4, 32, true);
@@ -110,17 +117,15 @@ public void test() {
110117
option.setAccuracyCalculation(acc);
111118

112119
/* training operator */
113-
DLTrainingOperator<float[], Integer> trainingOperator = new DLTrainingOperator<>(
114-
model, option, float[].class, Integer.class
115-
);
120+
DLTrainingDataQuantaBuilder<float[], Integer> trainingOperator =
121+
trainXSource.dlTraining(trainYSource, model, option);
116122

117123
/* predict operator */
118-
PredictOperator<float[], float[]> predictOperator = new PredictOperator<>(
119-
float[].class, float[].class
120-
);
124+
PredictDataQuantaBuilder<float[], float[]> predictOperator =
125+
trainingOperator.predict(testXSource, float[].class);
121126

122127
/* map to label */
123-
MapOperator<float[], Integer> mapOperator = new MapOperator<>(array -> {
128+
MapDataQuantaBuilder<float[], Integer> mapOperator = predictOperator.map(array -> {
124129
int maxIdx = 0;
125130
float maxVal = array[0];
126131
for (int i = 1; i < array.length; i++) {
@@ -130,69 +135,47 @@ public void test() {
130135
}
131136
}
132137
return maxIdx;
133-
}, float[].class, Integer.class);
138+
});
134139

135140
/* sink */
136-
List<Integer> predicted = new ArrayList<>();
137-
LocalCallbackSink<Integer> predictedSink = LocalCallbackSink.createCollectingSink(predicted, Integer.class);
138-
139-
List<Integer> groundTruth = new ArrayList<>();
140-
LocalCallbackSink<Integer> groundTruthSink = LocalCallbackSink.createCollectingSink(groundTruth, Integer.class);
141-
142-
trainXSource.connectTo(0, trainingOperator, 0);
143-
trainYSource.connectTo(0, trainingOperator, 1);
144-
trainingOperator.connectTo(0, predictOperator, 0);
145-
testXSource.connectTo(0, predictOperator, 1);
146-
predictOperator.connectTo(0, mapOperator, 0);
147-
mapOperator.connectTo(0, predictedSink, 0);
148-
testYSource.connectTo(0, groundTruthSink, 0);
149-
150-
WayangPlan wayangPlan = new WayangPlan(predictedSink, groundTruthSink);
151-
152-
WayangContext wayangContext = new WayangContext();
153-
wayangContext.register(Java.basicPlugin());
154-
wayangContext.register(Tensorflow.plugin());
155-
wayangContext.execute(wayangPlan);
141+
List<Integer> predicted = new ArrayList<>(mapOperator.collect());
142+
// fixme: Currently, wayang's scala-like api only supports a single collect,
143+
// so it is not possible to collect multiple result lists in a single plan.
144+
// List<Integer> groundTruth = new ArrayList<>(testYSource.collect());
156145

157146
System.out.println("predicted: " + predicted);
158-
System.out.println("ground truth: " + groundTruth);
159-
160-
float success = 0;
161-
for (int i = 0; i < predicted.size(); i++) {
162-
if (predicted.get(i).equals(groundTruth.get(i))) {
163-
success += 1;
164-
}
165-
}
166-
System.out.println("test accuracy: " + success / predicted.size());
147+
// System.out.println("ground truth: " + groundTruth);
148+
149+
// float success = 0;
150+
// for (int i = 0; i < predicted.size(); i++) {
151+
// if (predicted.get(i).equals(groundTruth.get(i))) {
152+
// success += 1;
153+
// }
154+
// }
155+
// System.out.println("test accuracy: " + success / predicted.size());
167156
}
168157

169-
public static Tuple<Operator, Operator> fileOperation(URI uri, boolean random) {
170-
TextFileSource textFileSource = new TextFileSource(uri.toString());
171-
MapOperator<String, Tuple> mapOperator = new MapOperator<>(line -> {
172-
String[] parts = line.split(",");
173-
float[] x = new float[parts.length - 1];
174-
for (int i = 0; i < x.length; i++) {
175-
x[i] = Float.parseFloat(parts[i]);
176-
}
177-
int y = LABEL_MAP.get(parts[parts.length - 1]);
178-
return new Tuple<>(x, y);
179-
}, String.class, Tuple.class);
180-
181-
MapOperator<Tuple, float[]> mapX = new MapOperator<>(tuple -> (float[]) tuple.field0, Tuple.class, float[].class);
182-
MapOperator<Tuple, Integer> mapY = new MapOperator<>(tuple -> (Integer) tuple.field1, Tuple.class, Integer.class);
158+
public static Tuple<DataQuantaBuilder<?, float[]>, DataQuantaBuilder<?, Integer>>
159+
fileOperation(JavaPlanBuilder plan, URI uri, boolean random) {
160+
DataQuantaBuilder<?, String> textFileSource = plan.readTextFile(uri.toString());
183161

184162
if (random) {
185163
Random r = new Random();
186-
SortOperator<String, Integer> randomOperator = new SortOperator<>(e -> r.nextInt(), String.class, Integer.class);
187-
188-
textFileSource.connectTo(0, randomOperator, 0);
189-
randomOperator.connectTo(0, mapOperator, 0);
190-
} else {
191-
textFileSource.connectTo(0, mapOperator, 0);
164+
textFileSource = textFileSource.sort(e -> r.nextInt());
192165
}
193166

194-
mapOperator.connectTo(0, mapX, 0);
195-
mapOperator.connectTo(0, mapY, 0);
167+
MapDataQuantaBuilder<String, Tuple<float[], Integer>> mapXY = textFileSource.map(line -> {
168+
String[] parts = line.split(",");
169+
float[] x = new float[parts.length - 1];
170+
for (int i = 0; i < x.length; i++) {
171+
x[i] = Float.parseFloat(parts[i]);
172+
}
173+
int y = LABEL_MAP.get(parts[parts.length - 1]);
174+
return new Tuple<>(x, y);
175+
});
176+
177+
MapDataQuantaBuilder<Tuple<float[], Integer>, float[]> mapX = mapXY.map(tuple -> tuple.field0);
178+
MapDataQuantaBuilder<Tuple<float[], Integer>, Integer> mapY = mapXY.map(tuple -> tuple.field1);
196179

197180
return new Tuple<>(mapX, mapY);
198181
}

0 commit comments

Comments
 (0)