1
1
package org .deeplearning4j .examples .samediff .tfimport ;
2
2
3
- import java .io .File ;
4
- import java .io .IOException ;
5
- import java .net .URL ;
6
- import java .util .Arrays ;
7
3
import org .apache .commons .io .FilenameUtils ;
8
4
import org .datavec .image .loader .ImageLoader ;
9
- import org .deeplearning4j .zoo .model .helper .InceptionResNetHelper ;
10
5
import org .deeplearning4j .zoo .util .imagenet .ImageNetLabels ;
11
6
import org .nd4j .autodiff .samediff .SameDiff ;
12
7
import org .nd4j .linalg .api .ndarray .INDArray ;
13
8
import org .nd4j .linalg .api .ops .DynamicCustomOp ;
14
9
import org .nd4j .linalg .factory .Nd4j ;
15
- import org .nd4j .linalg .indexing .INDArrayIndex ;
16
10
import org .nd4j .linalg .indexing .NDArrayIndex ;
17
11
import org .nd4j .resources .Downloader ;
18
12
13
+ import java .io .File ;
14
+ import java .io .IOException ;
15
+ import java .net .URL ;
16
+
19
17
/**
20
18
* This example shows the ability to import and use Tensorflow models, specifically mobilenet, and use them for inference.
21
19
*/
@@ -64,11 +62,8 @@ public static void main(String[] args) throws Exception {
64
62
}
65
63
66
64
67
-
68
- public static String MODEL_URL = "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz" ;
69
-
70
65
// download and extract the model file in the ~/dl4j-examples-data directory used by other examples
71
- public static File downloadModel () throws Exception {
66
+ static File downloadModel () throws Exception {
72
67
String dataDir = FilenameUtils .concat (System .getProperty ("user.home" ), "dl4j-examples-data/tf_resnet" );
73
68
String modelFile = FilenameUtils .concat (dataDir , "mobilenet_v2_1.0_224.tgz" );
74
69
@@ -78,14 +73,15 @@ public static File downloadModel() throws Exception{
78
73
return frozenFile ;
79
74
}
80
75
76
+ String MODEL_URL = "https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz" ;
81
77
Downloader .downloadAndExtract ("tf_resnet" , new URL (MODEL_URL ), new File (modelFile ), new File (dataDir ), "519bba7052fd279c66d2a28dc3f51f46" , 5 );
82
78
83
79
return frozenFile ;
84
80
}
85
81
86
82
// gets the image we use to test the network.
87
83
// This isn't a single class ImageNet image, so it won't do very well, but it will at least classify it as a dog or a cat.
88
- public static INDArray getTestImage () throws IOException {
84
+ private static INDArray getTestImage () throws IOException {
89
85
URL url = new URL ("https://github.com/tensorflow/models/blob/master/research/deeplab/g3doc/img/image2.jpg?raw=true" );
90
86
return new ImageLoader (358 , 500 , 3 ).asMatrix (url .openStream ());
91
87
}
@@ -99,7 +95,8 @@ public static INDArray getTestImage() throws IOException {
99
95
* @param height the height to resize to
100
96
* @param width the width to resize to
101
97
*/
102
- public static INDArray inceptionPreprocessing (INDArray img , int height , int width ){
98
+ @ SuppressWarnings ("SameParameterValue" )
99
+ private static INDArray inceptionPreprocessing (INDArray img , int height , int width ){
103
100
// add batch dimension
104
101
img = Nd4j .expandDims (img , 0 );
105
102
0 commit comments