Commit c09880d

author: Niccolo Dal Santo
committed: Initial commit
0 parents, commit c09880d

File tree: 87 files changed, +666 -0 lines changed


.gitattributes

Lines changed: 28 additions & 0 deletions
*.fig binary
*.mat binary
*.mdl binary diff merge=mlAutoMerge
*.mdlp binary
*.mexa64 binary
*.mexw64 binary
*.mexmaci64 binary
*.mlapp binary
*.mldatx binary
*.mlproj binary
*.mlx binary
*.p binary
*.sfx binary
*.sldd binary
*.slreqx binary merge=mlAutoMerge
*.slmx binary merge=mlAutoMerge
*.sltx binary
*.slxc binary
*.slx binary merge=mlAutoMerge
*.slxp binary

## Other common binary file types
*.docx binary
*.exe binary
*.jpg binary
*.pdf binary
*.png binary
*.xlsx binary

.gitignore

Lines changed: 1 addition & 0 deletions
# List of untracked files to ignore

ConvMixer.prj

Lines changed: 2 additions & 0 deletions
<?xml version="1.0" encoding="UTF-8"?>
<MATLABProject xmlns="http://www.mathworks.com/MATLABProjectFile" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" version="1.0"/>

README.md

Lines changed: 31 additions & 0 deletions
# ConvMixer -- Patches are all you need?

This demo shows how to implement and train a ConvMixer architecture for image classification with MATLAB&reg;, as described in the paper "Patches are all you need?" (https://openreview.net/forum?id=TVHS5Y4dNvM).

The ConvMixer architecture applies a patch embedding to the input, followed by repeated fully-convolutional blocks.

![ConvMixer Architecture](images/convMixer.png)

## How to get started

Open the project ConvMixer.prj to add the relevant functions to the path. The `convmixer/examples` folder contains examples to get you started with training a ConvMixer on the digits dataset and on the CIFAR-10 dataset [1].

The latter uses the Adam algorithm with fixed weight decay regularization, as described in [2].

Training a ConvMixer on CIFAR-10 can be demanding in terms of computational resources: the same `convmixer/examples` folder also contains a pretrained network. This model was trained on CIFAR-10, available at https://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz

The source code for building the architecture is in the `convmixer/convmixer` directory.

## Requirements

- MATLAB&reg; R2021b or later
- Deep Learning Toolbox&trade;

## License

The license is available in the license file within this repository.

Copyright 2021 The MathWorks, Inc.

[1] Krizhevsky, Alex. "Learning multiple layers of features from tiny images." (2009). https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf

[2] Loshchilov, Ilya, and Frank Hutter. "Fixing weight decay regularization in ADAM." (2018). https://openreview.net/forum?id=rk6qdGgCZ
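As a rough sketch of the workflow (the hyperparameter values below are illustrative, not the ones used in the shipped examples), a small ConvMixer can be built with `convMixerLayers` and trained on the digits dataset using the standard Deep Learning Toolbox functions `digitTrain4DArrayData`, `trainingOptions`, and `trainNetwork`:

```matlab
% Illustrative quick-start sketch (not part of the repository's examples).
[XTrain,YTrain] = digitTrain4DArrayData;      % 28x28x1 grayscale digit images

lgraph = convMixerLayers(InputSize=[28 28 1], ...
    PatchSize=4, KernelSize=5, Depth=4, ...
    HiddenDimension=64, NumClasses=10, ...
    ConnectOutputLayer=true);                  % append softmax + classification

opts = trainingOptions("adam", ...
    MaxEpochs=5, MiniBatchSize=128, ...
    Plots="training-progress");

net = trainNetwork(XTrain,YTrain,lgraph,opts);
```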

SECURITY.md

Lines changed: 6 additions & 0 deletions
# Reporting Security Vulnerabilities

If you believe you have discovered a security vulnerability, please report it to
[[email protected]](mailto:[email protected]). Please see the
[MathWorks Vulnerability Disclosure Policy for Security Researchers](https://www.mathworks.com/company/aboutus/policies_statements/vulnerability-disclosure-policy.html)
for additional information.

convmixer/convMixerLayers.m

Lines changed: 101 additions & 0 deletions
function lgraph = convMixerLayers(opts)
% convMixerLayers Build ConvMixer architecture.
%
% lgraph = convMixerLayers() returns a LayerGraph object with a ConvMixer
% architecture with default options as proposed in
% https://openreview.net/forum?id=TVHS5Y4dNvM.
%
% lgraph = convMixerLayers(PARAM1=VAL1,PARAM2=VAL2,...) specifies optional
% parameter name/value pairs for creating the layer graph:
%
%   'InputSize'          - Size of the input images.
%
%   'NumClasses'         - Number of classes the network predicts.
%
%   'KernelSize'         - Size of the kernel for the depthwise
%                          convolution.
%
%   'PatchSize'          - Size of the patches for the patch embedding
%                          layer.
%
%   'Depth'              - Number of repeated fully-convolutional
%                          blocks.
%
%   'HiddenDimension'    - Number of channels output by the patch
%                          embedding.
%
%   'ConnectOutputLayer' - Determines whether to append softmax and
%                          classification output layers to the
%                          returned LayerGraph object.
%
% Example:
%
%   lgraph = convMixerLayers(InputSize=[28 28 1], Depth=5, NumClasses=10)

% Copyright 2021 The MathWorks, Inc.

arguments
    opts.InputSize = [227 227 3]
    opts.NumClasses = 1000
    opts.KernelSize = 9
    opts.PatchSize = 7
    opts.Depth = 20
    opts.HiddenDimension = 1536
    opts.ConnectOutputLayer logical = false
end

input_size = opts.InputSize;
num_classes = opts.NumClasses;

kernel_size = opts.KernelSize;
patch_size = opts.PatchSize;
depth = opts.Depth;
hidden_dim = opts.HiddenDimension;
connectOutputLayers = opts.ConnectOutputLayer;

% The first layer is the patch embedding: a convolution whose kernel size
% and stride both equal the patch size.
patchEmbedding = convolution2dLayer(patch_size, hidden_dim, ...
    Stride=patch_size, ...
    Name="patchEmbedding", ...
    WeightsInitializer="glorot");

% Build the layer graph.
lgraph = layerGraph();

start = [
    imageInputLayer(input_size,Normalization="none")
    patchEmbedding
    geluLayer(Name="gelu_0")
    batchNormalizationLayer(Name="batchnorm_0")
    ];
lgraph = addLayers(lgraph,start);

for i = 1:depth
    convMixer = [
        groupedConvolution2dLayer(kernel_size,1,"channel-wise",Name="depthwiseConv_"+i,Padding="same",WeightsInitializer="glorot")
        geluLayer(Name="gelu_"+(2*i-1))
        batchNormalizationLayer(Name="batchnorm_"+(2*i-1))
        additionLayer(2,Name="addition_"+i)
        convolution2dLayer([1 1],hidden_dim,Name="pointwiseConv_"+i,WeightsInitializer="glorot")
        geluLayer(Name="gelu_"+2*i)
        batchNormalizationLayer(Name="batchnorm_"+2*i)
        ];
    lgraph = addLayers(lgraph,convMixer);
    % Residual connection: the previous block's output feeds both the
    % depthwise convolution and the addition layer's second input.
    lgraph = connectLayers(lgraph,"batchnorm_"+2*(i-1),"depthwiseConv_"+i);
    lgraph = connectLayers(lgraph,"batchnorm_"+2*(i-1),"addition_"+i+"/in2");
end

gapFc = [
    globalAveragePooling2dLayer(Name="GAP")
    fullyConnectedLayer(num_classes,Name="fc") % named explicitly so it can be connected below
    ];
lgraph = addLayers(lgraph,gapFc);
lgraph = connectLayers(lgraph,"batchnorm_"+2*depth,"GAP");

if connectOutputLayers
    lgraph = addLayers(lgraph, softmaxLayer('Name','softmax'));
    lgraph = addLayers(lgraph, classificationLayer('Name','classification'));
    lgraph = connectLayers(lgraph,'fc','softmax');
    lgraph = connectLayers(lgraph,'softmax','classification');
end
end
convmixer/geluLayer.m

Lines changed: 46 additions & 0 deletions
classdef geluLayer < nnet.layer.Layer
% geluLayer GELU layer.
%
% gLayer = geluLayer() returns a geluLayer object.
%
% gLayer = geluLayer(PARAM1=VAL1,PARAM2=VAL2,...) specifies optional
% parameter name/value pairs for creating the layer:
%
%   'Mode' - Approximation used for the GELU activation.
%            Options are 'fast' (tanh approximation) and 'exact'.
%
%   'Name' - Name of the layer.
%
% See https://paperswithcode.com/method/gelu for details.
%
% Example:
%
%   gLayer = geluLayer()

% Copyright 2021 The MathWorks, Inc.

properties(SetAccess='private')
    Mode
end

methods
    function obj = geluLayer(opts)
        arguments
            opts.Mode string {mustBeMember(opts.Mode,["fast", "exact"])} = "fast";
            opts.Name string {mustBeText} = "gelu";
        end
        obj.Name = opts.Name;
        obj.Mode = opts.Mode;
    end

    function y = predict(obj,x)
        switch obj.Mode
            case "exact"
                y = x/2.*(1+erf(x/sqrt(2)));
            case "fast"
                y = x/2.*(1+tanh(sqrt(2/pi)*(x+0.044715*x.^3)));
            otherwise
                error("geluLayer: unknown Mode """ + obj.Mode + """.")
        end
    end
end
end

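For reference, the two modes correspond to the standard GELU definitions: the exact form via the Gaussian CDF $\Phi$, and the widely used tanh approximation.

```latex
\mathrm{GELU}(x) = x\,\Phi(x)
  = \frac{x}{2}\left(1 + \operatorname{erf}\!\left(\frac{x}{\sqrt{2}}\right)\right)
  \approx \frac{x}{2}\left(1 + \tanh\!\left(\sqrt{\tfrac{2}{\pi}}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right)
```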
data/downloadCIFARData.m

Lines changed: 15 additions & 0 deletions
function downloadCIFARData(destination)
% downloadCIFARData Download CIFAR-10 dataset.

% Copyright 2021 The MathWorks, Inc.

url = 'https://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz';

unpackedData = fullfile(destination,'cifar-10-batches-mat');
if ~exist(unpackedData,'dir')
    fprintf('Downloading CIFAR-10 dataset (175 MB). This can take a while...');
    untar(url,destination);
    fprintf('done.\n\n');
end

end

data/loadCIFARData.m

Lines changed: 32 additions & 0 deletions
function [XTrain,YTrain,XTest,YTest] = loadCIFARData(location)
% loadCIFARData Load the CIFAR-10 dataset and split it into training and
% test sets.

% Copyright 2021 The MathWorks, Inc.

location = fullfile(location,'cifar-10-batches-mat');

[XTrain1,YTrain1] = loadBatchAsFourDimensionalArray(location,'data_batch_1.mat');
[XTrain2,YTrain2] = loadBatchAsFourDimensionalArray(location,'data_batch_2.mat');
[XTrain3,YTrain3] = loadBatchAsFourDimensionalArray(location,'data_batch_3.mat');
[XTrain4,YTrain4] = loadBatchAsFourDimensionalArray(location,'data_batch_4.mat');
[XTrain5,YTrain5] = loadBatchAsFourDimensionalArray(location,'data_batch_5.mat');
XTrain = cat(4,XTrain1,XTrain2,XTrain3,XTrain4,XTrain5);
YTrain = [YTrain1;YTrain2;YTrain3;YTrain4;YTrain5];

[XTest,YTest] = loadBatchAsFourDimensionalArray(location,'test_batch.mat');
end

function [XBatch,YBatch] = loadBatchAsFourDimensionalArray(location,batchFileName)
s = load(fullfile(location,batchFileName));
% Each row of s.data is one 32x32x3 image stored row-major; reshape and
% swap the spatial dimensions to obtain H-by-W-by-C-by-N arrays.
XBatch = s.data';
XBatch = reshape(XBatch,32,32,3,[]);
XBatch = permute(XBatch,[2 1 3 4]);
YBatch = convertLabelsToCategorical(location,s.labels);
end

function categoricalLabels = convertLabelsToCategorical(location,integerLabels)
s = load(fullfile(location,'batches.meta.mat'));
categoricalLabels = categorical(integerLabels,0:9,s.label_names);
end
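A typical use of the two data utilities above might look like this (the destination folder here is illustrative):

```matlab
% Sketch: fetch CIFAR-10 if needed, then load it into memory.
datadir = tempdir;                              % illustrative destination folder
downloadCIFARData(datadir);                     % no-op if already unpacked
[XTrain,YTrain,XTest,YTest] = loadCIFARData(datadir);
size(XTrain)                                    % 32x32x3xN training images
```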
(7.72 MB binary file not shown)
