-
-
Notifications
You must be signed in to change notification settings - Fork 21
Expand file tree
/
Copy pathProgram.cs
More file actions
213 lines (190 loc) · 8.58 KB
/
Program.cs
File metadata and controls
213 lines (190 loc) · 8.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
using System.Runtime.CompilerServices;
using System;
using System.IO;
using System.Linq;
using CNTK;
using CNTKUtil;
using XPlot.Plotly;
namespace CatsAndDogs
{
/// <summary>
/// The main program class.
/// </summary>
class Program
{
    // full paths of the generated mapping files used by the CNTK image readers
    private static string trainMapPath = Path.Combine(Environment.CurrentDirectory, "train_map.txt");
    private static string testMapPath = Path.Combine(Environment.CurrentDirectory, "test_map.txt");

    // number of image PAIRS (one cat + one dog) per partition
    private const int trainingSetSize = 1600; // 80% of 2000 images
    private const int testingSetSize = 400;   // 20% of 2000 images

    // image details fed to the readers and the input variable
    private const int imageWidth = 150;
    private const int imageHeight = 150;
    private const int numChannels = 3;

    /// <summary>
    /// Create the mapping files for features and labels.
    /// Each line is "relativePath\tlabel" with label 0 for cats and 1 for dogs,
    /// alternating one cat and one dog per iteration. The first
    /// <see cref="trainingSetSize"/> pairs go to the training map, the next
    /// <see cref="testingSetSize"/> pairs to the testing map.
    /// </summary>
    static void CreateMappingFiles()
    {
        // get both classes of images
        // NOTE: Directory.GetFiles returns FULL paths. The original code did
        // Path.Combine("cat", fullPath), which is a no-op (a rooted second
        // argument wins), so the maps silently contained absolute paths.
        // We reduce each entry to its filename to get the intended
        // relative paths like "cat\img001.jpg".
        var class0Images = Directory.GetFiles(Path.Combine(Environment.CurrentDirectory, "cat"));
        var class1Images = Directory.GetFiles(Path.Combine(Environment.CurrentDirectory, "dog"));

        // generate train and test mapping files
        // assumes each folder holds at least trainingSetSize + testingSetSize
        // images — TODO confirm against the data set download step
        var mappingFiles = new string[] { trainMapPath, testMapPath };
        var partitionSizes = new int[] { trainingSetSize, testingSetSize };
        var imageIndex = 0;
        for (int mapIndex = 0; mapIndex < mappingFiles.Length; mapIndex++)
        {
            var filePath = mappingFiles[mapIndex];
            using (var dstFile = new StreamWriter(filePath))
            {
                for (var i = 0; i < partitionSizes[mapIndex]; i++)
                {
                    // BUGFIX: keep only the filename so the map holds the
                    // intended relative path instead of the absolute one
                    var class0Path = Path.Combine("cat", Path.GetFileName(class0Images[imageIndex]));
                    var class1Path = Path.Combine("dog", Path.GetFileName(class1Images[imageIndex]));
                    dstFile.WriteLine($"{class0Path}\t0");
                    dstFile.WriteLine($"{class1Path}\t1");
                    imageIndex++;
                }
            }
            Console.WriteLine($" Created file: {filePath}");
        }
        Console.WriteLine();
    }

    /// <summary>
    /// The main program entry point: builds the mapping files, downloads the
    /// pretrained VGG16 weights, trains a transfer-learning classifier for
    /// <c>maxEpochs</c> epochs, reports per-epoch losses/errors, and saves an
    /// accuracy chart to chart.html.
    /// </summary>
    /// <param name="args">The command line arguments (unused).</param>
    static void Main(string[] args)
    {
        // create the mapping files
        Console.WriteLine("Creating mapping files...");
        CreateMappingFiles();

        // download the VGG16 network weights if we don't have them yet
        Console.WriteLine("Downloading VGG16...");
        if (!DataUtil.VGG16.IsDownloaded)
        {
            DataUtil.VGG16.Download();
        }

        // get a training and testing image readers; only training data
        // is shuffled and augmented
        var trainingReader = DataUtil.GetImageReader(trainMapPath, imageWidth, imageHeight, numChannels, 2, randomizeData: true, augmentData: true);
        var testingReader = DataUtil.GetImageReader(testMapPath, imageWidth, imageHeight, numChannels, 2, randomizeData: false, augmentData: false);

        // build features and labels
        var features = NetUtil.Var(new int[] { imageHeight, imageWidth, numChannels }, DataType.Float);
        var labels = NetUtil.Var(new int[] { 2 }, DataType.Float);

        // build the network: frozen VGG16 base + small trainable head
        var network = features
            .MultiplyBy<float>(1.0f / 255.0f) // divide all pixels by 255
            .VGG16(allowBlock5Finetuning: false)
            .Dense(256, CNTKLib.ReLU)
            .Dropout(0.5)
            .Dense(2, CNTKLib.Softmax)
            .ToNetwork();
        Console.WriteLine("Model architecture:");
        Console.WriteLine(network.ToSummary());

        // set up the loss function and the classification error function
        // NOTE(review): the head already ends in Softmax and
        // CrossEntropyWithSoftmax applies softmax again — the usual CNTK
        // pattern is a linear final layer here. Left unchanged because
        // downstream consumers may expect probability outputs; verify.
        var lossFunction = CNTKLib.CrossEntropyWithSoftmax(network.Output, labels);
        var errorFunction = CNTKLib.ClassificationError(network.Output, labels);

        // use the Adam learning algorithm
        var learner = network.GetAdamLearner(
            learningRateSchedule: (0.0001, 1),
            momentumSchedule: (0.99, 1));

        // set up a trainer and an evaluator
        var trainer = network.GetTrainer(learner, lossFunction, errorFunction);
        var evaluator = network.GetEvaluator(errorFunction);

        // train the model
        Console.WriteLine("Epoch\tTrain\tTrain\tTest");
        Console.WriteLine("\tLoss\tError\tError");
        Console.WriteLine("-----------------------------");
        var maxEpochs = 25;
        var batchSize = 16;
        var loss = new double[maxEpochs];
        var trainingError = new double[maxEpochs];
        var testingError = new double[maxEpochs];
        var batchCount = 0;
        for (int epoch = 0; epoch < maxEpochs; epoch++)
        {
            // train one epoch on batches
            loss[epoch] = 0.0;
            trainingError[epoch] = 0.0;
            // BUGFIX: reset per epoch — previously batchCount carried over
            // the testing batch count from the prior epoch, skewing the
            // training loss/error averages from epoch 1 onward
            batchCount = 0;
            var sampleCount = 0;
            while (sampleCount < 2 * trainingSetSize)
            {
                // get the current batch for training
                var batch = trainingReader.GetBatch(batchSize);
                var featuresBatch = batch[trainingReader.StreamInfo("features")];
                var labelsBatch = batch[trainingReader.StreamInfo("labels")];

                // train the model on the batch
                var result = trainer.TrainBatch(
                    new[] {
                        (features, featuresBatch),
                        (labels, labelsBatch)
                    }
                );
                loss[epoch] += result.Loss;
                trainingError[epoch] += result.Evaluation;
                sampleCount += (int)featuresBatch.numberOfSamples;
                batchCount++;
            }

            // show average training results for this epoch
            loss[epoch] /= batchCount;
            trainingError[epoch] /= batchCount;
            Console.Write($"{epoch}\t{loss[epoch]:F3}\t{trainingError[epoch]:F3}\t");

            // test one epoch on batches
            testingError[epoch] = 0.0;
            batchCount = 0;
            sampleCount = 0;
            while (sampleCount < 2 * testingSetSize)
            {
                // get the current batch for testing
                var batch = testingReader.GetBatch(batchSize);
                var featuresBatch = batch[testingReader.StreamInfo("features")];
                var labelsBatch = batch[testingReader.StreamInfo("labels")];

                // test the model on the batch
                testingError[epoch] += evaluator.TestBatch(
                    new[] {
                        (features, featuresBatch),
                        (labels, labelsBatch)
                    }
                );
                sampleCount += (int)featuresBatch.numberOfSamples;
                batchCount++;
            }

            // show average testing error for this epoch
            testingError[epoch] /= batchCount;
            Console.WriteLine($"{testingError[epoch]:F3}");
        }

        // show final results from the last epoch
        var finalError = testingError[maxEpochs - 1];
        Console.WriteLine();
        Console.WriteLine($"Final test error: {finalError:0.00}");
        Console.WriteLine($"Final test accuracy: {1 - finalError:0.00}");

        // plot training and testing accuracy (1 - error) per epoch
        var chart = Chart.Plot(
            new[]
            {
                new Graph.Scatter()
                {
                    x = Enumerable.Range(0, maxEpochs).ToArray(),
                    y = trainingError.Select(v => 1 - v),
                    name = "training",
                    mode = "lines+markers"
                },
                new Graph.Scatter()
                {
                    x = Enumerable.Range(0, maxEpochs).ToArray(),
                    y = testingError.Select(v => 1 - v),
                    name = "testing",
                    mode = "lines+markers"
                }
            }
        );
        chart.WithXTitle("Epoch");
        chart.WithYTitle("Accuracy");
        chart.WithTitle("Cats and Dogs Training");

        // save chart
        File.WriteAllText("chart.html", chart.GetHtml());
    }
}
}