|
22 | 22 | import org.apache.wayang.basic.model.op.Input; |
23 | 23 | import org.apache.wayang.basic.model.op.Op; |
24 | 24 | import org.apache.wayang.basic.model.optimizer.Optimizer; |
25 | | -import org.tensorflow.Graph; |
26 | | -import org.tensorflow.Operand; |
27 | | -import org.tensorflow.Session; |
28 | | -import org.tensorflow.Tensor; |
| 25 | +import org.tensorflow.*; |
29 | 26 | import org.tensorflow.ndarray.*; |
30 | 27 | import org.tensorflow.ndarray.index.Indices; |
31 | 28 | import org.tensorflow.op.Ops; |
@@ -122,14 +119,12 @@ void train(XT x, YT y, int epoch, int batchSize) { |
122 | 119 | if (accuracyCalculation != null) { |
123 | 120 | runner.fetch(accuracyCalculation.getName()); |
124 | 121 | } |
125 | | - List<Tensor> ret = runner.run(); |
126 | | - try (TFloat32 loss = (TFloat32) ret.get(0)) { |
| 122 | + try (Result ret = runner.run()) { |
| 123 | + TFloat32 loss = (TFloat32) ret.get(0); |
127 | 124 | System.out.printf("[epoch %d, batch %d] loss: %f ", i + 1, start / batchSize + 1, loss.getFloat()); |
128 | | - } |
129 | | - if (accuracyCalculation != null) { |
130 | | - try (TFloat32 acc = (TFloat32) ret.get(1)) { |
131 | | - System.out.printf("accuracy: %f ", acc.getFloat()); |
132 | | - } |
| 125 | + |
| 126 | + TFloat32 acc = (TFloat32) ret.get(1); |
| 127 | + System.out.printf("accuracy: %f ", acc.getFloat()); |
133 | 128 | } |
134 | 129 | System.out.println(); |
135 | 130 | } |
|
0 commit comments