Skip to content

Commit b2ec280

Browse files
committed
updating Tensorflow versioning
1 parent 6980739 commit b2ec280

File tree

1 file changed

+6
-11
lines changed
  • wayang-platforms/wayang-tensorflow/src/main/java/org/apache/wayang/tensorflow/model

1 file changed

+6
-11
lines changed

wayang-platforms/wayang-tensorflow/src/main/java/org/apache/wayang/tensorflow/model/TensorflowModel.java

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@
2222
import org.apache.wayang.basic.model.op.Input;
2323
import org.apache.wayang.basic.model.op.Op;
2424
import org.apache.wayang.basic.model.optimizer.Optimizer;
25-
import org.tensorflow.Graph;
26-
import org.tensorflow.Operand;
27-
import org.tensorflow.Session;
28-
import org.tensorflow.Tensor;
25+
import org.tensorflow.*;
2926
import org.tensorflow.ndarray.*;
3027
import org.tensorflow.ndarray.index.Indices;
3128
import org.tensorflow.op.Ops;
@@ -122,14 +119,12 @@ void train(XT x, YT y, int epoch, int batchSize) {
122119
if (accuracyCalculation != null) {
123120
runner.fetch(accuracyCalculation.getName());
124121
}
125-
List<Tensor> ret = runner.run();
126-
try (TFloat32 loss = (TFloat32) ret.get(0)) {
122+
try (Result ret = runner.run()) {
123+
TFloat32 loss = (TFloat32) ret.get(0);
127124
System.out.printf("[epoch %d, batch %d] loss: %f ", i + 1, start / batchSize + 1, loss.getFloat());
128-
}
129-
if (accuracyCalculation != null) {
130-
try (TFloat32 acc = (TFloat32) ret.get(1)) {
131-
System.out.printf("accuracy: %f ", acc.getFloat());
132-
}
125+
126+
TFloat32 acc = (TFloat32) ret.get(1);
127+
System.out.printf("accuracy: %f ", acc.getFloat());
133128
}
134129
System.out.println();
135130
}

0 commit comments

Comments
 (0)