|
1 | | -%% Train ConvMixer Network with Digit Dataset |
2 | | -% |
3 | | -% This example shows how to build a ConvMixer network architecture and |
4 | | -% train it with the digit dataset. |
5 | | - |
6 | | -% Copyright 2021 The MathWorks, Inc. |
7 | | - |
8 | | -[XTrain, TTrain] = digitTrain4DArrayData; |
9 | | -[XTest, TTest] = digitTest4DArrayData; |
10 | | - |
11 | | -% Build a ConvMixer layergraph |
12 | | -lg = convMixerLayers( ... |
13 | | - InputSize=[28 28 1], ... |
14 | | - NumClasses=10, ... |
15 | | - HiddenDimension=256, ... |
16 | | - Depth=5, ... |
17 | | - ConnectOutputLayer=true, ... |
18 | | - PatchSize=1); |
19 | | - |
20 | | -% Instantiate options |
21 | | -options = trainingOptions( "sgdm", ... |
22 | | - MaxEpochs=10, ... |
23 | | - Shuffle="every-epoch", ... |
24 | | - ValidationData={XTest, TTest}, ... |
25 | | - Plots="training-progress" ); |
26 | | - |
27 | | -% Train ConvMixer |
28 | | -net = trainNetwork(XTrain, TTrain, lg, options); |
| 1 | +%% Train ConvMixer Network with Digit Dataset |
| 2 | +% |
| 3 | +% This example shows how to build a ConvMixer network architecture and |
| 4 | +% train it with the digit dataset. |
| 5 | + |
| 6 | +% Copyright 2021 The MathWorks, Inc. |
| 7 | + |
| 8 | +[XTrain, TTrain] = digitTrain4DArrayData; |
| 9 | +[XTest, TTest] = digitTest4DArrayData; |
| 10 | + |
| 11 | +% Build a ConvMixer layergraph |
| 12 | +lg = convMixerLayers( ... |
| 13 | + InputSize=[28 28 1], ... |
| 14 | + NumClasses=10, ... |
| 15 | + HiddenDimension=256, ... |
| 16 | + Depth=5, ... |
| 17 | + ConnectOutputLayer=true, ... |
| 18 | + PatchSize=1); |
| 19 | + |
| 20 | +% Instantiate options |
| 21 | +options = trainingOptions( "sgdm", ... |
| 22 | + MaxEpochs=10, ... |
| 23 | + Shuffle="every-epoch", ... |
| 24 | + ValidationData={XTest, TTest}, ... |
| 25 | + Plots="training-progress" ); |
| 26 | + |
| 27 | +% Train ConvMixer |
| 28 | +net = trainNetwork(XTrain, TTrain, lg, options); |
0 commit comments