Skip to content

Commit 998b3d0

Browse files
author
KentaItakura
committed
first commit
0 parents  commit 998b3d0

File tree

13 files changed

+363
-0
lines changed

13 files changed

+363
-0
lines changed

README.md

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
# Image Classification using Convolutional Neural Network with Multi-Input
2+
3+
**[English]**
4+
This demo shows how to implement convolutional neural network (CNN) for image classification with multi-input using `custom loop` method. As an example, a dataset of hand-written digits called MNIST was divided into the upper half and down half as shown below and the upper and down part were fed into the multi input CNN.
5+
**[Japanese]**
6+
2種類の画像を入力できる畳み込みニューラルネットワークのデモです。2つの入力層があって、例えば、入力層Aには、動物の顔の画像を入力し、入力層Bには、その動物の足の画像を入力する、などです。2019bバージョンからカスタムループと呼ばれる方法が可能になり、深層学習のより詳細なカスタマイズが可能となりました。簡単にためせるように、手書き数字の上半分と下半分をそれぞれ別の入力層からインプットし、畳み込みなどを行った後に得られた特徴量を結合させ、さらに全結合層などで計算を進めています。よりこの例に適切なデータや課題などがございましたら教えていただけると幸いです。まだまだ作りこみ不足なところもあり、今後も更新していければよいと考えています。
7+
8+
![image_0.png](README_images/image_0.png)
9+
10+
![image_1.png](README_images/image_1.png)
11+
12+
![image_2.png](README_images/image_2.png)
13+
14+
The figure above shows the classification accuracy with the multi-input CNN. The top and down part of the digits were fed into the multi-input CNN, the accuracy was over 96 %. If only the top or down part were used for the CNN, the accuracy was significantly lower than that with multi-input.
15+
16+
17+
18+
![image_3.png](README_images/image_3.png)
19+
20+
Note that this figure is cited from ref [1]. The paper was talking about video classification, not still image classification. However, the fusion model described above is very infomative. In my understanding, this demo is similar to early late fusion, (but please do confirm). Other types of fusion may be implemented in my future work. In ref [2], they proposed a deep learning model called TM-CNN for multi-lane traffic speed prediction, which would be related to this demo.
21+
22+
[1] Karpathy, A., Toderici, G., Shetty, S., Leung, T., Sukthankar, R., \& Fei-Fei, L. (2014). Large-scale video classification with convolutional neural networks. In *Proceedings of the IEEE conference on Computer Vision and Pattern Recognition* (pp. 1725-1732).
23+
24+
[2] Ke, R., Li, W., Cui, Z., \& Wang, Y. (2019). Two-stream multi-channel convolutional neural network (TM-CNN) for multi-lane traffic speed prediction considering traffic volume impact. *arXiv preprint arXiv:1903.01678*.
25+
26+
# Data preparation
27+
28+
This script saves the hand-digit dataset into sub-folders. Use prepareDigitDataset function to create `upperHalf` and `bottomHalf` folders.
29+
30+
```matlab:Code
31+
clear;clc;close all
32+
if exist('bottomHalf')~=7 % the data is already prepared. This section is skipped.
33+
disp('Preparing demo dataset for this script')
34+
prepareDigitDataset
35+
end
36+
```
37+
38+
# Store the images into `imagedatastore`
39+
40+
```matlab:Code
41+
inputSize=[14 28];
42+
firstFolderName='upperHalf';
43+
secondFolderName='bottomHalf';
44+
imdsUpper = imageDatastore(strcat(firstFolderName,filesep), 'IncludeSubfolders',true, 'LabelSource','foldernames');
45+
imdsBottom = imageDatastore(strcat(secondFolderName,filesep), 'IncludeSubfolders',true, 'LabelSource','foldernames');
46+
augmenter = imageDataAugmenter('RandXReflection',false);
47+
augimdsUpper = augmentedImageDatastore(inputSize,imdsUpper,'DataAugmentation',augmenter);
48+
augimdsBottom = augmentedImageDatastore(inputSize,imdsBottom,'DataAugmentation',augmenter);
49+
numAll=numel(imdsBottom.Files);
50+
```
51+
52+
# Dividing into training, validataion and test dataset
53+
54+
```matlab:Code
55+
% The ratio is specified here
56+
TrainRatio=0.8;
57+
ValidRatio=0.1;
58+
TestRatio=1-TrainRatio-ValidRatio;
59+
```
60+
61+
Use the helper function `partitionData`. It separate the dataset with the ratio as defined.
62+
63+
```matlab:Code
64+
[XTrainUpper,XTrainBottom,XValidUpper,XValidBottom,XTestUpper,XTestBottom,YTrain,YValid,YTest]=partitionData(augimdsUpper,augimdsBottom,TrainRatio,ValidRatio,numAll,imdsUpper.Labels);
65+
classes = categories(YTrain); % retrieve the class names
66+
numClasses = numel(classes); % the number of classes
67+
```
68+
69+
# Define convolutional neural network model
70+
71+
![image_4.png](README_images/image_4.png)
72+
73+
```matlab:Code
74+
numHiddenDimension=20; % speficy the dimension of the hidden layer
75+
layers = createSimpleLayer(XTrainUpper,numHiddenDimension);
76+
layers2 = createSimpleLayer(XTrainBottom,numHiddenDimension);
77+
```
78+
79+
When the two layers are merged, the same name of the layers cannot be used. Use renameLayer function to rename the layer name in `layers2`
80+
81+
```matlab:Code
82+
layers2=renameLayer(layers2,'_2');
83+
layersAdd=[fullyConnectedLayer(20,'Name','fcAdd1')
84+
fullyConnectedLayer(numClasses,'Name','fcAdd2')];
85+
layersRemoved=[layers(1:end);concatenationLayer(1,2,'Name','cat');layersAdd];
86+
lgraphAggregated = addLayers(layerGraph(layersRemoved),layers2(1:end));
87+
lgraphAggregated = connectLayers(lgraphAggregated,'fc_2','cat/in2');
88+
```
89+
90+
Covert into deep learning network for custom training loops using `dlnetwork`
91+
92+
```matlab:Code
93+
dlnet = dlnetwork(lgraphAggregated); % A dlnetwork object enables support for custom training loops using automatic differentiation
94+
```
95+
96+
# Specify training options
97+
98+
```matlab:Code
99+
miniBatchSize = 16; % mini batch size. When you run out of memory, decrease this value like 4
100+
numEpochs = 30; % max epoch
101+
numObservations = numel(YTrain); % the number of training data
102+
numIterationsPerEpoch = floor(numObservations./miniBatchSize); % number of iterations per epoch
103+
executionEnvironment = "gpu"; % Set "gpu" when you use gpu
104+
```
105+
106+
Initial setting for `Adam` optimizer
107+
108+
```matlab:Code
109+
averageGrad = [];
110+
averageSqGrad = [];
111+
iteration = 1; % initialize iteration
112+
```
113+
114+
# Create `animated line`
115+
116+
`animatedline` creates an animated line that has no data and adds it to the current axes. Create an animation by adding points to the line in a loop using the `addpoints` function.
117+
118+
```matlab:Code
119+
plots = "training-progress";
120+
if plots == "training-progress"
121+
f1=figure;
122+
lineLossTrain = animatedline('Color','r');
123+
xlabel("Total Iterations")
124+
ylabel("Loss");lineLossValid = animatedline('Color','b');
125+
xlabel("Total Iterations");ylabel("LossValid")
126+
end
127+
```
128+
129+
# Prepare the validation data
130+
131+
The validation data is called during training to check the CNN performance.
132+
133+
```matlab:Code
134+
YValidPlot=zeros(numClasses,numel(YValid),'single');
135+
for c = 1:numClasses
136+
YValidPlot(c,YValid==classes(c)) = 1;
137+
end
138+
% Convert mini-batch of data to a dlarray.
139+
dlXValidUpper=dlarray(single(XValidUpper),'SSCB');
140+
dlXValidBottom=dlarray(single(XValidBottom),'SSCB');
141+
142+
% If training on a GPU, then convert data to a gpuArray.
143+
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
144+
dlXValidUpper = gpuArray(dlXValidUpper);
145+
dlXValidBottom = gpuArray(dlXValidBottom);
146+
end
147+
```
148+
149+
# Train network in custom training loop
150+
151+
```matlab:Code
152+
for epoch = 1:numEpochs
153+
% Shuffle data.
154+
idx = randperm(numel(YTrain));
155+
XTrainUpper = XTrainUpper(:,:,:,idx);
156+
XTrainBottom = XTrainBottom(:,:,:,idx);
157+
YTrain=YTrain(idx);
158+
159+
for i = 1:numIterationsPerEpoch
160+
161+
% Read mini-batch of data and convert the labels to dummy
162+
% variables.
163+
idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
164+
XUpper = XTrainUpper(:,:,:,idx);
165+
XBottom = XTrainBottom(:,:,:,idx);
166+
167+
Y = zeros(numClasses, miniBatchSize, 'single');
168+
for c = 1:numClasses
169+
Y(c,YTrain(idx)==classes(c)) = 1;
170+
end
171+
172+
% Convert mini-batch of data to a dlarray.
173+
dlXUpper = dlarray(single(XUpper),'SSCB');
174+
dlXBottom = dlarray(single(XBottom),'SSCB');
175+
176+
% If training on a GPU, then convert data to a gpuArray.
177+
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
178+
dlXUpper = gpuArray(dlXUpper);
179+
dlXBottom = gpuArray(dlXBottom);
180+
end
181+
182+
% Evaluate the model gradients and loss using dlfeval and the
183+
% modelGradients helper function.
184+
[grad,loss] = dlfeval(@modelGradientsMulti,dlnet,dlXUpper,dlXBottom,Y);
185+
lossValid = modelLossMulti(dlnet,dlXValidUpper,dlXValidBottom,YValidPlot);
186+
% Update the network parameters using the Adam optimizer.
187+
[dlnet,averageGrad,averageSqGrad] = adamupdate(dlnet,grad,averageGrad,averageSqGrad,iteration,0.0005);
188+
189+
% Display the training progress.
190+
if plots == "training-progress"
191+
addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
192+
title("Loss During Training: Epoch - " + epoch + "; Iteration - " + i)
193+
addpoints(lineLossValid,iteration,double(gather(extractdata(lossValid))))
194+
title("Loss During Validation: Epoch - " + epoch + "; Iteration - " + i)
195+
drawnow
196+
end
197+
198+
% Increment the iteration counter.
199+
iteration = iteration + 1;
200+
end
201+
end
202+
```
203+
204+
![figure_1.png](README_images/figure_1.png)
205+
206+
# Compute classification accuracy
207+
208+
```matlab:Code
209+
dlXTestUpper = dlarray(single(XTestUpper),'SSCB');
210+
dlXTestBottom = dlarray(single(XTestBottom),'SSCB');
211+
```
212+
213+
Convert the test data into `gpuArray` to accelarate with GPU
214+
215+
```matlab:Code
216+
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
217+
dlXTestUpper = gpuArray(dlXTestUpper);
218+
dlXTestBottom = gpuArray(dlXTestBottom);
219+
end
220+
```
221+
222+
Two similar function are available in MATLAB for calculating the output of deep learning network.
223+
224+
`predict`: Compute deep learning network output for inference
225+
226+
`forward`: Compute deep learning network output for training
227+
228+
The difference is either for training or testing. In the training phase, some techniques like batch normalization and dropout are employed while they are not used in testing.
229+
230+
```matlab:Code
231+
dlYPred = predict(dlnet,dlXTestUpper,dlXTestBottom); % use predict for testing
232+
[~,idx] = max(extractdata(dlYPred),[],1); % extract the class with highest score
233+
YPred = classes(idx);
234+
```
235+
236+
Calculate the overall accuracy
237+
238+
```matlab:Code
239+
accuracy = mean(YPred==YTest)
240+
```
241+
242+
```text:Output
243+
accuracy = 0.9720
244+
```
245+
246+
Display confusion matrix
247+
248+
```matlab:Code
249+
confusionchart(YTest,categorical(cellstr(YPred)))
250+
```
251+
252+
![figure_2.png](README_images/figure_2.png)
253+
254+
# Helper functions
255+
256+
```matlab:Code
257+
function layers=createSimpleLayer(XTrainData_4D,numHiddenDimension)
258+
layers = [
259+
imageInputLayer([14 28 3],"Name","imageinput","Mean",mean(XTrainData_4D,4))
260+
convolution2dLayer([3 3],8,"Name","conv_1","Padding","same")
261+
reluLayer("Name","relu_1")
262+
maxPooling2dLayer([2 2],"Name","maxpool_1","Stride",[2 2])
263+
convolution2dLayer([3 3],16,"Name","conv_2","Padding","same")
264+
reluLayer("Name","relu_2")
265+
maxPooling2dLayer([2 2],"Name","maxpool_2","Stride",[2 2])
266+
convolution2dLayer([3 3],32,"Name","conv_3","Padding","same")
267+
reluLayer("Name","relu_3")
268+
fullyConnectedLayer(numHiddenDimension,"Name","fc")];
269+
end
270+
271+
function [gradients,loss] = modelGradientsMulti(dlnet,dlXupper,dlXBottom,Y)
272+
273+
dlYPred = forward(dlnet,dlXupper,dlXBottom);
274+
dlYPred = softmax(dlYPred);
275+
276+
loss = crossentropy(dlYPred,Y);
277+
gradients = dlgradient(loss,dlnet.Learnables);
278+
279+
end
280+
281+
function layers=renameLayer(layers,char)
282+
for i=1:numel(layers)
283+
layers(i).Name=[layers(i).Name,char];
284+
end
285+
end
286+
287+
function loss = modelLossMulti(dlnet,dlXUpper,dlXBottom,Y)
288+
dlYPred = forward(dlnet,dlXUpper,dlXBottom);
289+
dlYPred = softmax(dlYPred);
290+
loss = crossentropy(dlYPred,Y);
291+
end
292+
```

README_images/figure_1.png

31.6 KB
Loading

README_images/figure_2.png

37.3 KB
Loading

README_images/image_0.png

370 KB
Loading

README_images/image_1.png

137 KB
Loading

README_images/image_2.png

10.9 KB
Loading

README_images/image_3.png

122 KB
Loading

README_images/image_4.png

31.9 KB
Loading

changeLog.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
v2: The dlnet1 to 3 were merged into a single dlnet. Firstly uploaded into github as well as MATLAB file exchange
2+
v1: Three dlnet1 to 3 were prepared then trained using forward functions

license.txt

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
Copyright (c) 2021, Kenta Itakura
2+
All rights reserved.
3+
4+
Redistribution and use in source and binary forms, with or without
5+
modification, are permitted provided that the following conditions are met:
6+
7+
* Redistributions of source code must retain the above copyright notice, this
8+
list of conditions and the following disclaimer.
9+
10+
* Redistributions in binary form must reproduce the above copyright notice,
11+
this list of conditions and the following disclaimer in the documentation
12+
and/or other materials provided with the distribution
13+
* Neither the name of nor the names of its
14+
contributors may be used to endorse or promote products derived from this
15+
software without specific prior written permission.
16+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
20+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

0 commit comments

Comments
 (0)