|
49 | 49 | import org.deeplearning4j.util.ModelSerializer; |
50 | 50 | import org.deeplearning4j.zoo.model.TinyYOLO; |
51 | 51 | import org.nd4j.linalg.activations.Activation; |
| 52 | +import org.nd4j.linalg.api.memory.enums.DebugMode; |
52 | 53 | import org.nd4j.linalg.api.ndarray.INDArray; |
53 | 54 | import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; |
54 | 55 | import org.nd4j.linalg.factory.Nd4j; |
55 | 56 | import org.nd4j.linalg.learning.config.Adam; |
| 57 | +import org.nd4j.linalg.profiler.ProfilerConfig; |
56 | 58 | import org.slf4j.Logger; |
57 | 59 | import org.slf4j.LoggerFactory; |
58 | 60 |
|
@@ -122,18 +124,17 @@ public static void main(String[] args) throws java.lang.Exception { |
122 | 124 | File trainDir = fetcher.getDataSetPath(DataSetType.TRAIN); |
123 | 125 | File testDir = fetcher.getDataSetPath(DataSetType.TEST); |
124 | 126 |
|
125 | | - |
126 | 127 | log.info("Load data..."); |
127 | 128 |
|
128 | 129 | FileSplit trainData = new FileSplit(trainDir, NativeImageLoader.ALLOWED_FORMATS, rng); |
129 | 130 | FileSplit testData = new FileSplit(testDir, NativeImageLoader.ALLOWED_FORMATS, rng); |
130 | 131 |
|
131 | 132 | ObjectDetectionRecordReader recordReaderTrain = new ObjectDetectionRecordReader(height, width, nChannels, |
132 | | - gridHeight, gridWidth, new SvhnLabelProvider(trainDir)); |
| 133 | + gridHeight, gridWidth, new SvhnLabelProvider(trainDir)); |
133 | 134 | recordReaderTrain.initialize(trainData); |
134 | 135 |
|
135 | 136 | ObjectDetectionRecordReader recordReaderTest = new ObjectDetectionRecordReader(height, width, nChannels, |
136 | | - gridHeight, gridWidth, new SvhnLabelProvider(testDir)); |
| 137 | + gridHeight, gridWidth, new SvhnLabelProvider(testDir)); |
137 | 138 | recordReaderTest.initialize(testData); |
138 | 139 |
|
139 | 140 | // ObjectDetectionRecordReader performs regression, so we need to specify it here |
@@ -210,7 +211,7 @@ public static void main(String[] args) throws java.lang.Exception { |
210 | 211 | CanvasFrame frame = new CanvasFrame("HouseNumberDetection"); |
211 | 212 | OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat(); |
212 | 213 | org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer yout = |
213 | | - (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer)model.getOutputLayer(0); |
| 214 | + (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer)model.getOutputLayer(0); |
214 | 215 | List<String> labels = train.getLabels(); |
215 | 216 | test.setCollectMetaData(true); |
216 | 217 | Scalar[] colormap = {RED,BLUE,GREEN,CYAN,YELLOW,MAGENTA,ORANGE,PINK,LIGHTBLUE,VIOLET}; |
|
0 commit comments