Skip to content

Commit 6980739

Browse files
committed
updating Tensorflow versioning
1 parent eb4b228 commit 6980739

File tree

6 files changed

+92
-9
lines changed

6 files changed

+92
-9
lines changed

wayang-api/wayang-api-scala-java/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
<properties>
3636
<java-module-name>org.apache.wayang.api</java-module-name>
37-
<tensorflow.version>0.4.2</tensorflow.version>
37+
<tensorflow.version>1.0.0-rc.2</tensorflow.version>
3838
</properties>
3939

4040
<dependencyManagement>

wayang-platforms/wayang-java/src/test/java/org/apache/wayang/java/operators/JavaTextFileSourceTest.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.junit.Assert;
3535
import org.junit.Before;
3636
import org.junit.Test;
37+
import org.junit.Ignore;
3738

3839
import java.io.File;
3940
import java.io.IOException;
@@ -97,7 +98,7 @@ public void testReadLocalFile() throws IOException, URISyntaxException {
9798
}
9899
}
99100

100-
// @Test
101+
@Ignore
101102
/**
102103
* Requires a local HTTP Server running, in the project root ...
103104
*
@@ -128,7 +129,7 @@ public void testReadRemoteFileHTTP() throws IOException, URISyntaxException {
128129
}
129130
}
130131

131-
@Test
132+
@Ignore
132133
public void testReadRemoteFileHTTPS() throws IOException, URISyntaxException {
133134
final String testFileURL = "https://kamir.solidcommunity.net/public/ecolytiq-sustainability-profile/profile2.ttl";
134135

wayang-platforms/wayang-tensorflow/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
<properties>
3737
<maven.compiler.source>11</maven.compiler.source>
3838
<maven.compiler.target>11</maven.compiler.target>
39-
<tensorflow.version>0.4.2</tensorflow.version>
39+
<tensorflow.version>1.0.0-rc.2</tensorflow.version>
4040
</properties>
4141

4242
<dependencies>
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.wayang.tensorflow.model;
20+
21+
import org.apache.wayang.basic.model.DLModel;
22+
import org.apache.wayang.basic.model.op.*;
23+
import org.apache.wayang.basic.model.op.nn.CrossEntropyLoss;
24+
import org.apache.wayang.basic.model.op.nn.Linear;
25+
import org.apache.wayang.basic.model.op.nn.Sigmoid;
26+
import org.apache.wayang.basic.model.optimizer.GradientDescent;
27+
import org.apache.wayang.basic.model.optimizer.Optimizer;
28+
import org.junit.Test;
29+
import org.junit.Ignore;
30+
import org.tensorflow.ndarray.FloatNdArray;
31+
import org.tensorflow.ndarray.IntNdArray;
32+
import org.tensorflow.ndarray.NdArrays;
33+
import org.tensorflow.ndarray.Shape;
34+
import org.tensorflow.op.Ops;
35+
import org.tensorflow.types.TFloat32;
36+
import org.tensorflow.types.TInt32;
37+
public class TensorflowModelTest {
38+
@Ignore
39+
public void test() {
40+
FloatNdArray x = NdArrays.ofFloats(Shape.of(6, 4))
41+
.set(NdArrays.vectorOf(5.1f, 3.5f, 1.4f, 0.2f), 0)
42+
.set(NdArrays.vectorOf(4.9f, 3.0f, 1.4f, 0.2f), 1)
43+
.set(NdArrays.vectorOf(6.9f, 3.1f, 4.9f, 1.5f), 2)
44+
.set(NdArrays.vectorOf(5.5f, 2.3f, 4.0f, 1.3f), 3)
45+
.set(NdArrays.vectorOf(5.8f, 2.7f, 5.1f, 1.9f), 4)
46+
.set(NdArrays.vectorOf(6.7f, 3.3f, 5.7f, 2.5f), 5)
47+
;
48+
IntNdArray y = NdArrays.vectorOf(0, 0, 1, 1, 2, 2);
49+
Op l1 = new Linear(4, 64, true);
50+
Op s1 = new Sigmoid();
51+
Op l2 = new Linear(64, 3, true);
52+
s1.with(l1.with(new Input(Input.Type.FEATURES)));
53+
l2.with(s1);
54+
DLModel model = new DLModel(l2);
55+
Op criterion = new CrossEntropyLoss(3);
56+
criterion.with(
57+
new Input(Input.Type.PREDICTED, Op.DType.FLOAT32),
58+
new Input(Input.Type.LABEL, Op.DType.INT32)
59+
);
60+
Op acc = new Mean(0);
61+
acc.with(new Cast(Op.DType.FLOAT32).with(new Eq().with(
62+
new ArgMax(1).with(new Input(Input.Type.PREDICTED, Op.DType.FLOAT32)),
63+
new Input(Input.Type.LABEL, Op.DType.INT32)
64+
)));
65+
Optimizer optimizer = new GradientDescent(0.02f);
66+
try (TensorflowModel tfModel = new TensorflowModel(model, criterion, optimizer, acc)) {
67+
System.out.println(tfModel.getOut().getName());
68+
tfModel.train(x, y, 100, 6);
69+
TFloat32 predicted = tfModel.predict(x);
70+
Ops tf = Ops.create();
71+
org.tensorflow.op.math.ArgMax<TInt32> argMax = tf.math.argMax(tf.constantOf(predicted), tf.constant(1), TInt32.class);
72+
final TInt32 tensor = argMax.asTensor();
73+
System.out.print("[ ");
74+
for (int i = 0; i < tensor.shape().size(0); i++) {
75+
System.out.print(tensor.getInt(i) + " ");
76+
}
77+
System.out.println("]");
78+
}
79+
System.out.println();
80+
}
81+
}

wayang-tests-integration/src/test/java/org/apache/wayang/tests/TensorflowIntegrationIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030
import org.apache.wayang.core.plan.wayangplan.WayangPlan;
3131
import org.apache.wayang.java.Java;
3232
import org.apache.wayang.tensorflow.Tensorflow;
33-
import org.junit.Ignore;
3433
import org.junit.Test;
34+
import org.junit.Ignore;
3535

3636
import java.util.ArrayList;
3737
import java.util.Arrays;

wayang-tests-integration/src/test/java/org/apache/wayang/tests/TensorflowIrisIT.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,23 @@
2424
import org.apache.wayang.basic.model.op.nn.Linear;
2525
import org.apache.wayang.basic.model.op.nn.Sigmoid;
2626
import org.apache.wayang.basic.model.optimizer.Adam;
27-
import org.apache.wayang.basic.model.optimizer.GradientDescent;
2827
import org.apache.wayang.basic.model.optimizer.Optimizer;
2928
import org.apache.wayang.basic.operators.*;
3029
import org.apache.wayang.core.api.WayangContext;
3130
import org.apache.wayang.core.plan.wayangplan.Operator;
3231
import org.apache.wayang.core.plan.wayangplan.WayangPlan;
3332
import org.apache.wayang.core.util.Tuple;
34-
import org.apache.wayang.core.util.WayangCollections;
3533
import org.apache.wayang.java.Java;
3634
import org.apache.wayang.tensorflow.Tensorflow;
37-
import org.junit.Ignore;
3835
import org.junit.Test;
36+
import org.junit.Ignore;
3937

4038
import java.net.URI;
4139
import java.net.URISyntaxException;
42-
import java.util.*;
40+
import java.util.ArrayList;
41+
import java.util.List;
42+
import java.util.Map;
43+
import java.util.Random;
4344

4445
/**
4546
* Test the Tensorflow integration with Wayang.

0 commit comments

Comments
 (0)