Skip to content

Commit e4be2ad

Browse files
committed
tfimport.
Signed-off-by: Robert Altena <[email protected]>
1 parent cb1519e commit e4be2ad

File tree

2 files changed

+15
-22
lines changed

2 files changed

+15
-22
lines changed

dl4j-examples/src/main/java/org/deeplearning4j/examples/samediff/tfimport/SameDiffTFImportMobileNetExample.java

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,19 @@
11
package org.deeplearning4j.examples.samediff.tfimport;
22

3-
import java.io.File;
4-
import java.io.IOException;
5-
import java.net.URL;
6-
import java.util.Arrays;
73
import org.apache.commons.io.FilenameUtils;
84
import org.datavec.image.loader.ImageLoader;
9-
import org.deeplearning4j.zoo.model.helper.InceptionResNetHelper;
105
import org.deeplearning4j.zoo.util.imagenet.ImageNetLabels;
116
import org.nd4j.autodiff.samediff.SameDiff;
127
import org.nd4j.linalg.api.ndarray.INDArray;
138
import org.nd4j.linalg.api.ops.DynamicCustomOp;
149
import org.nd4j.linalg.factory.Nd4j;
15-
import org.nd4j.linalg.indexing.INDArrayIndex;
1610
import org.nd4j.linalg.indexing.NDArrayIndex;
1711
import org.nd4j.resources.Downloader;
1812

13+
import java.io.File;
14+
import java.io.IOException;
15+
import java.net.URL;
16+
1917
/**
2018
* This example shows the ability to import and use Tensorflow models, specifically mobilenet, and use them for inference.
2119
*/
@@ -64,11 +62,8 @@ public static void main(String[] args) throws Exception {
6462
}
6563

6664

67-
68-
public static String MODEL_URL = "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz";
69-
7065
// download and extract the model file in the ~/dl4j-examples-data directory used by other examples
71-
public static File downloadModel() throws Exception{
66+
static File downloadModel() throws Exception{
7267
String dataDir = FilenameUtils.concat(System.getProperty("user.home"), "dl4j-examples-data/tf_resnet");
7368
String modelFile = FilenameUtils.concat(dataDir, "mobilenet_v2_1.0_224.tgz");
7469

@@ -78,14 +73,15 @@ public static File downloadModel() throws Exception{
7873
return frozenFile;
7974
}
8075

76+
String MODEL_URL = "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz";
8177
Downloader.downloadAndExtract("tf_resnet", new URL(MODEL_URL), new File(modelFile), new File(dataDir), "519bba7052fd279c66d2a28dc3f51f46", 5);
8278

8379
return frozenFile;
8480
}
8581

8682
// gets the image we use to test the network.
8783
// This isn't a single class ImageNet image, so it won't do very well, but it will at least classify it as a dog or a cat.
88-
public static INDArray getTestImage() throws IOException {
84+
private static INDArray getTestImage() throws IOException {
8985
URL url = new URL("https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/img/image2.jpg?raw=true");
9086
return new ImageLoader(358, 500, 3).asMatrix(url.openStream());
9187
}
@@ -99,7 +95,8 @@ public static INDArray getTestImage() throws IOException {
9995
* @param height the height to resize to
10096
* @param width the width to resize to
10197
*/
102-
public static INDArray inceptionPreprocessing(INDArray img, int height, int width){
98+
@SuppressWarnings("SameParameterValue")
99+
private static INDArray inceptionPreprocessing(INDArray img, int height, int width){
103100
// add batch dimension
104101
img = Nd4j.expandDims(img, 0);
105102

dl4j-examples/src/main/java/org/deeplearning4j/examples/samediff/tfimport/SameDiffTransferLearningExample.java

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
11
package org.deeplearning4j.examples.samediff.tfimport;
22

3-
import java.io.File;
4-
import java.util.Arrays;
5-
import java.util.Collections;
6-
import java.util.List;
73
import org.deeplearning4j.datasets.fetchers.DataSetType;
84
import org.deeplearning4j.datasets.iterator.impl.Cifar10DataSetIterator;
9-
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
105
import org.deeplearning4j.examples.samediff.training.SameDiffCustomListenerExample;
116
import org.deeplearning4j.examples.samediff.training.SameDiffMNISTTrainingExample;
127
import org.nd4j.autodiff.listeners.At;
@@ -21,25 +16,25 @@
2116
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
2217
import org.nd4j.autodiff.samediff.transform.GraphTransformUtil;
2318
import org.nd4j.autodiff.samediff.transform.OpPredicate;
24-
import org.nd4j.autodiff.samediff.transform.SubGraph;
2519
import org.nd4j.autodiff.samediff.transform.SubGraphPredicate;
26-
import org.nd4j.autodiff.samediff.transform.SubGraphProcessor;
2720
import org.nd4j.evaluation.classification.Evaluation;
2821
import org.nd4j.evaluation.classification.Evaluation.Metric;
2922
import org.nd4j.linalg.api.buffer.DataType;
3023
import org.nd4j.linalg.api.ndarray.INDArray;
3124
import org.nd4j.linalg.api.ops.DynamicCustomOp;
32-
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
33-
import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D;
3425
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D;
3526
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
36-
import org.nd4j.linalg.dataset.api.DataSet;
3727
import org.nd4j.linalg.dataset.api.MultiDataSet;
3828
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
3929
import org.nd4j.linalg.factory.Nd4j;
4030
import org.nd4j.linalg.learning.config.Adam;
4131
import org.nd4j.weightinit.impl.XavierInitScheme;
4232

33+
import java.io.File;
34+
import java.util.Arrays;
35+
import java.util.Collections;
36+
import java.util.List;
37+
4338
/**
4439
* This is an example of doing transfer learning by importing a tensorflow model of mobilenet and replacing the last layer.
4540
*
@@ -50,6 +45,7 @@
5045
* See {@link SameDiffCustomListenerExample} for an example of how to use custom listeners (we use one here to find the shapes of an activation).
5146
*
5247
*/
48+
@SuppressWarnings("unused") //
5349
public class SameDiffTransferLearningExample {
5450

5551
public static void main(String[] args) throws Exception {

0 commit comments

Comments
 (0)