Skip to content

Commit 391ec1b

Browse files
authored
Merge pull request #897 from rnett/rn_sd_examples
[WIP] SameDiff examples
2 parents be13c23 + 337e0d9 commit 391ec1b

File tree

5 files changed

+690
-0
lines changed

5 files changed

+690
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
package org.deeplearning4j.examples.samediff.tfimport;
2+
3+
import java.io.File;
4+
import java.io.IOException;
5+
import java.net.URL;
6+
import java.util.Arrays;
7+
import org.apache.commons.io.FilenameUtils;
8+
import org.datavec.image.loader.ImageLoader;
9+
import org.deeplearning4j.zoo.model.helper.InceptionResNetHelper;
10+
import org.deeplearning4j.zoo.util.imagenet.ImageNetLabels;
11+
import org.nd4j.autodiff.samediff.SameDiff;
12+
import org.nd4j.linalg.api.ndarray.INDArray;
13+
import org.nd4j.linalg.api.ops.DynamicCustomOp;
14+
import org.nd4j.linalg.factory.Nd4j;
15+
import org.nd4j.linalg.indexing.INDArrayIndex;
16+
import org.nd4j.linalg.indexing.NDArrayIndex;
17+
import org.nd4j.resources.Downloader;
18+
19+
/**
20+
* This example shows the ability to import and use Tensorflow models, specifically mobilenet, and use them for inference.
21+
*/
22+
public class SameDiffTFImportMobileNetExample {
23+
24+
public static void main(String[] args) throws Exception {
25+
26+
// download and extract a tensorflow frozen model file (usually a .pb file)
27+
File modelFile = downloadModel();
28+
29+
// import the frozen model into a SameDiff instance
30+
SameDiff sd = SameDiff.importFrozenTF(modelFile);
31+
32+
System.out.println(sd.summary());
33+
34+
System.out.println("\n\n");
35+
36+
// get the image from https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/img/image2.jpg for testing
37+
INDArray testImage = getTestImage();
38+
39+
// preprocess image with inception preprocessing
40+
INDArray preprocessedImage = inceptionPreprocessing(testImage, 224, 224);
41+
42+
// Input and output names are found by looking at sd.summary() (printed earlyer).
43+
// The input variable is the output of no ops, and the output variable is the input of no ops.
44+
45+
// Alternatively, you can use sd.outputs() and sd.inputs().
46+
47+
System.out.println("Input: " + sd.inputs());
48+
System.out.println("Output: " + sd.outputs());
49+
50+
// Do inference for a single batch.
51+
INDArray out = sd.batchOutput()
52+
.input("input", preprocessedImage)
53+
.output("MobilenetV2/Predictions/Reshape_1")
54+
.execSingle();
55+
56+
// ignore label 0 (the background label)
57+
out = out.get(NDArrayIndex.all(), NDArrayIndex.interval(1, 1001));
58+
59+
// get the readable label for the classes
60+
String label = new ImageNetLabels().decodePredictions(out);
61+
62+
System.out.println("Predictions: " + label);
63+
64+
}
65+
66+
67+
68+
public static String MODEL_URL = "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz";
69+
70+
// download and extract the model file in the ~/dl4j-examples-data directory used by other examples
71+
public static File downloadModel() throws Exception{
72+
String dataDir = FilenameUtils.concat(System.getProperty("user.home"), "dl4j-examples-data/tf_resnet");
73+
String modelFile = FilenameUtils.concat(dataDir, "mobilenet_v2_1.0_224.tgz");
74+
75+
File frozenFile = new File(FilenameUtils.concat(dataDir, "mobilenet_v2_1.0_224_frozen.pb"));
76+
77+
if(frozenFile.exists()){
78+
return frozenFile;
79+
}
80+
81+
Downloader.downloadAndExtract("tf_resnet", new URL(MODEL_URL), new File(modelFile), new File(dataDir), "519bba7052fd279c66d2a28dc3f51f46", 5);
82+
83+
return frozenFile;
84+
}
85+
86+
// gets the image we use to test the network.
87+
// This isn't a single class ImageNet image, so it won't do very well, but it will at least classify it as a dog or a cat.
88+
public static INDArray getTestImage() throws IOException {
89+
URL url = new URL("https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/img/image2.jpg?raw=true");
90+
return new ImageLoader(358, 500, 3).asMatrix(url.openStream());
91+
}
92+
93+
/**
94+
* Does inception preprocessing. Takes an image with shape [c, h, w]
95+
* and returns an image with shape [1, height, width, c].
96+
*
97+
* Eventually this will be made part of DL4J.
98+
*
99+
* @param height the height to resize to
100+
* @param width the width to resize to
101+
*/
102+
public static INDArray inceptionPreprocessing(INDArray img, int height, int width){
103+
// add batch dimension
104+
img = Nd4j.expandDims(img, 0);
105+
106+
// change to channels-last
107+
img = img.permute(0, 2, 3, 1);
108+
109+
// normalize to 0-1
110+
img = img.div(256);
111+
112+
// resize
113+
INDArray preprocessedImage = Nd4j.createUninitialized(1, height, width, 3);
114+
115+
DynamicCustomOp op = DynamicCustomOp.builder("resize_bilinear")
116+
.addInputs(img)
117+
.addOutputs(preprocessedImage)
118+
.addIntegerArguments(height, width).build();
119+
Nd4j.exec(op);
120+
121+
// finish preprocessing
122+
preprocessedImage = preprocessedImage.sub(0.5);
123+
preprocessedImage = preprocessedImage.mul(2);
124+
return preprocessedImage;
125+
}
126+
127+
}

0 commit comments

Comments
 (0)