1818
1919package org .apache .wayang .tests ;
2020
21+ import org .apache .wayang .api .*;
2122import org .apache .wayang .basic .model .DLModel ;
2223import org .apache .wayang .basic .model .op .*;
2324import org .apache .wayang .basic .model .op .nn .CrossEntropyLoss ;
2425import org .apache .wayang .basic .model .op .nn .Linear ;
2526import org .apache .wayang .basic .model .op .nn .Sigmoid ;
2627import org .apache .wayang .basic .model .optimizer .Adam ;
2728import org .apache .wayang .basic .model .optimizer .Optimizer ;
28- import org .apache .wayang .basic .operators .* ;
29+ import org .apache .wayang .basic .operators .DLTrainingOperator ;
2930import org .apache .wayang .core .api .WayangContext ;
30- import org .apache .wayang .core .plan .wayangplan .Operator ;
31- import org .apache .wayang .core .plan .wayangplan .WayangPlan ;
3231import org .apache .wayang .core .util .Tuple ;
3332import org .apache .wayang .java .Java ;
3433import org .apache .wayang .tensorflow .Tensorflow ;
35- import org .junit .Ignore ;
34+ import org .junit .Test ;
3635
3736import java .net .URI ;
3837import java .net .URISyntaxException ;
@@ -56,22 +55,30 @@ public class TensorflowIrisScalaLikeApiIT {
5655 "Iris-virginica" , 2
5756 );
5857
59- @ Ignore
58+ @ Test
6059 public void test () {
61- final Tuple <Operator , Operator > trainSource = fileOperation (TRAIN_PATH , true );
62- final Tuple <Operator , Operator > testSource = fileOperation (TEST_PATH , false );
60+ WayangContext wayangContext = new WayangContext ()
61+ .with (Java .basicPlugin ())
62+ .with (Tensorflow .plugin ());
63+
64+ JavaPlanBuilder plan = new JavaPlanBuilder (wayangContext );
65+
66+ final Tuple <DataQuantaBuilder <?, float []>, DataQuantaBuilder <?, Integer >> trainSource =
67+ fileOperation (plan , TRAIN_PATH , true );
68+ final Tuple <DataQuantaBuilder <?, float []>, DataQuantaBuilder <?, Integer >> testSource =
69+ fileOperation (plan , TEST_PATH , false );
6370
6471 /* training features */
65- Operator trainXSource = trainSource .field0 ;
72+ DataQuantaBuilder <?, float []> trainXSource = trainSource .field0 ;
6673
6774 /* training labels */
68- Operator trainYSource = trainSource .field1 ;
75+ DataQuantaBuilder <?, Integer > trainYSource = trainSource .field1 ;
6976
7077 /* test features */
71- Operator testXSource = testSource .field0 ;
78+ DataQuantaBuilder <?, float []> testXSource = testSource .field0 ;
7279
7380 /* test labels */
74- Operator testYSource = testSource .field1 ;
81+ DataQuantaBuilder <?, Integer > testYSource = testSource .field1 ;
7582
7683 /* model */
7784 Op l1 = new Linear (4 , 32 , true );
@@ -110,17 +117,15 @@ public void test() {
110117 option .setAccuracyCalculation (acc );
111118
112119 /* training operator */
113- DLTrainingOperator <float [], Integer > trainingOperator = new DLTrainingOperator <>(
114- model , option , float [].class , Integer .class
115- );
120+ DLTrainingDataQuantaBuilder <float [], Integer > trainingOperator =
121+ trainXSource .dlTraining (trainYSource , model , option );
116122
117123 /* predict operator */
118- PredictOperator <float [], float []> predictOperator = new PredictOperator <>(
119- float [].class , float [].class
120- );
124+ PredictDataQuantaBuilder <float [], float []> predictOperator =
125+ trainingOperator .predict (testXSource , float [].class );
121126
122127 /* map to label */
123- MapOperator <float [], Integer > mapOperator = new MapOperator <> (array -> {
128+ MapDataQuantaBuilder <float [], Integer > mapOperator = predictOperator . map (array -> {
124129 int maxIdx = 0 ;
125130 float maxVal = array [0 ];
126131 for (int i = 1 ; i < array .length ; i ++) {
@@ -130,69 +135,47 @@ public void test() {
130135 }
131136 }
132137 return maxIdx ;
133- }, float []. class , Integer . class );
138+ });
134139
135140 /* sink */
136- List <Integer > predicted = new ArrayList <>();
137- LocalCallbackSink <Integer > predictedSink = LocalCallbackSink .createCollectingSink (predicted , Integer .class );
138-
139- List <Integer > groundTruth = new ArrayList <>();
140- LocalCallbackSink <Integer > groundTruthSink = LocalCallbackSink .createCollectingSink (groundTruth , Integer .class );
141-
142- trainXSource .connectTo (0 , trainingOperator , 0 );
143- trainYSource .connectTo (0 , trainingOperator , 1 );
144- trainingOperator .connectTo (0 , predictOperator , 0 );
145- testXSource .connectTo (0 , predictOperator , 1 );
146- predictOperator .connectTo (0 , mapOperator , 0 );
147- mapOperator .connectTo (0 , predictedSink , 0 );
148- testYSource .connectTo (0 , groundTruthSink , 0 );
149-
150- WayangPlan wayangPlan = new WayangPlan (predictedSink , groundTruthSink );
151-
152- WayangContext wayangContext = new WayangContext ();
153- wayangContext .register (Java .basicPlugin ());
154- wayangContext .register (Tensorflow .plugin ());
155- wayangContext .execute (wayangPlan );
141+ List <Integer > predicted = new ArrayList <>(mapOperator .collect ());
142+ // fixme: Currently, wayang's scala-like api only supports a single collect,
143+ // so it is not possible to collect multiple result lists in a single plan.
144+ // List<Integer> groundTruth = new ArrayList<>(testYSource.collect());
156145
157146 System .out .println ("predicted: " + predicted );
158- System .out .println ("ground truth: " + groundTruth );
159-
160- float success = 0 ;
161- for (int i = 0 ; i < predicted .size (); i ++) {
162- if (predicted .get (i ).equals (groundTruth .get (i ))) {
163- success += 1 ;
164- }
165- }
166- System .out .println ("test accuracy: " + success / predicted .size ());
147+ // System.out.println("ground truth: " + groundTruth);
148+
149+ // float success = 0;
150+ // for (int i = 0; i < predicted.size(); i++) {
151+ // if (predicted.get(i).equals(groundTruth.get(i))) {
152+ // success += 1;
153+ // }
154+ // }
155+ // System.out.println("test accuracy: " + success / predicted.size());
167156 }
168157
169- public static Tuple <Operator , Operator > fileOperation (URI uri , boolean random ) {
170- TextFileSource textFileSource = new TextFileSource (uri .toString ());
171- MapOperator <String , Tuple > mapOperator = new MapOperator <>(line -> {
172- String [] parts = line .split ("," );
173- float [] x = new float [parts .length - 1 ];
174- for (int i = 0 ; i < x .length ; i ++) {
175- x [i ] = Float .parseFloat (parts [i ]);
176- }
177- int y = LABEL_MAP .get (parts [parts .length - 1 ]);
178- return new Tuple <>(x , y );
179- }, String .class , Tuple .class );
180-
181- MapOperator <Tuple , float []> mapX = new MapOperator <>(tuple -> (float []) tuple .field0 , Tuple .class , float [].class );
182- MapOperator <Tuple , Integer > mapY = new MapOperator <>(tuple -> (Integer ) tuple .field1 , Tuple .class , Integer .class );
158+ public static Tuple <DataQuantaBuilder <?, float []>, DataQuantaBuilder <?, Integer >>
159+ fileOperation (JavaPlanBuilder plan , URI uri , boolean random ) {
160+ DataQuantaBuilder <?, String > textFileSource = plan .readTextFile (uri .toString ());
183161
184162 if (random ) {
185163 Random r = new Random ();
186- SortOperator <String , Integer > randomOperator = new SortOperator <>(e -> r .nextInt (), String .class , Integer .class );
187-
188- textFileSource .connectTo (0 , randomOperator , 0 );
189- randomOperator .connectTo (0 , mapOperator , 0 );
190- } else {
191- textFileSource .connectTo (0 , mapOperator , 0 );
164+ textFileSource = textFileSource .sort (e -> r .nextInt ());
192165 }
193166
194- mapOperator .connectTo (0 , mapX , 0 );
195- mapOperator .connectTo (0 , mapY , 0 );
167+ MapDataQuantaBuilder <String , Tuple <float [], Integer >> mapXY = textFileSource .map (line -> {
168+ String [] parts = line .split ("," );
169+ float [] x = new float [parts .length - 1 ];
170+ for (int i = 0 ; i < x .length ; i ++) {
171+ x [i ] = Float .parseFloat (parts [i ]);
172+ }
173+ int y = LABEL_MAP .get (parts [parts .length - 1 ]);
174+ return new Tuple <>(x , y );
175+ });
176+
177+ MapDataQuantaBuilder <Tuple <float [], Integer >, float []> mapX = mapXY .map (tuple -> tuple .field0 );
178+ MapDataQuantaBuilder <Tuple <float [], Integer >, Integer > mapY = mapXY .map (tuple -> tuple .field1 );
196179
197180 return new Tuple <>(mapX , mapY );
198181 }
0 commit comments