Skip to content

Commit 3096796

Browse files
committed
Work on the convolutional layers
1 parent b13f8ec commit 3096796

File tree

8 files changed

+70
-43
lines changed

8 files changed

+70
-43
lines changed

src/main/java/net/echo/brain4j/layer/Layer.java

Lines changed: 13 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -35,11 +35,12 @@ public abstract class Layer<I, O> {
3535
protected final Activations activation;
3636
protected final Activation function;
3737
protected Layer<?, ?> nextLayer;
38+
protected int id;
3839

3940
public Layer(int input, Activations activation) {
40-
Parameters.TOTAL_LAYERS++;
4141
Stream.generate(Neuron::new).limit(input).forEach(neurons::add);
4242

43+
this.id = Parameters.TOTAL_LAYERS++;
4344
this.activation = activation;
4445
this.function = activation.getFunction();
4546
}
@@ -82,32 +83,32 @@ public void updateWeights(Vector[] synapseMatrixLayer) {
8283
throw new UnsupportedOperationException("Not implemented for this class.");
8384
}
8485

85-
public void applyFunction(StatesCache cacheHolder, Layer<?, ?> previous) {
86-
function.apply(cacheHolder, neurons);
86+
public void applyFunction(StatesCache cache, Layer<?, ?> previous) {
87+
function.apply(cache, neurons);
8788
}
8889

89-
public void setInput(StatesCache cacheHolder, Vector input) {
90+
public void setInput(StatesCache cache, Vector input) {
9091
Preconditions.checkState(input.size() == neurons.size(), "Input size does not match!" +
9192
" (Input != Expected) " + input.size() + " != " + neurons.size());
9293

9394
for (int i = 0; i < input.size(); i++) {
94-
neurons.get(i).setValue(cacheHolder, input.get(i));
95+
neurons.get(i).setValue(cache, input.get(i));
9596
}
9697
}
9798

98-
public void propagate(StatesCache cacheHolder, Layer<?, ?> previous, Updater updater, Optimizer optimizer) {
99+
public void propagate(StatesCache cache, Layer<?, ?> previous, Updater updater, Optimizer optimizer) {
99100
int nextLayerSize = nextLayer.getNeurons().size();
100101

101102
for (int i = 0; i < neurons.size(); i++) {
102103
Neuron neuron = neurons.get(i);
103104

104-
double value = neuron.getValue(cacheHolder);
105+
double value = neuron.getValue(cache);
105106
double derivative = activation.getFunction().getDerivative(value);
106107

107108
for (int j = 0; j < nextLayerSize; j++) {
108109
Synapse synapse = synapses.get(i * nextLayerSize + j);
109110

110-
float weightChange = calculateGradient(cacheHolder, synapse, derivative);
111+
float weightChange = calculateGradient(cache, synapse, derivative);
111112
updater.acknowledgeChange(synapse, weightChange);
112113
}
113114
}
@@ -148,4 +149,8 @@ public int getTotalParams() {
148149
public int getTotalNeurons() {
149150
return neurons.size();
150151
}
152+
153+
public int getId() {
154+
return id;
155+
}
151156
}

src/main/java/net/echo/brain4j/layer/impl/convolution/ConvLayer.java

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import net.echo.brain4j.activation.Activations;
55
import net.echo.brain4j.convolution.Kernel;
66
import net.echo.brain4j.layer.Layer;
7+
import net.echo.brain4j.structure.cache.Parameters;
78
import net.echo.brain4j.structure.cache.StatesCache;
89
import net.echo.brain4j.training.optimizers.Optimizer;
910
import net.echo.brain4j.training.updater.Updater;
@@ -63,6 +64,7 @@ public ConvLayer(int filters, int kernelWidth, int kernelHeight, int stride, Act
6364
*/
6465
public ConvLayer(int filters, int kernelWidth, int kernelHeight, int stride, int padding, Activations activation) {
6566
super(0, activation);
67+
this.id = Parameters.TOTAL_CONV_LAYER++;
6668
this.filters = filters;
6769
this.kernelWidth = kernelWidth;
6870
this.kernelHeight = kernelHeight;
@@ -115,7 +117,7 @@ public Kernel forward(StatesCache cache, Layer<?, ?> lastLayer, Kernel input) {
115117
}
116118

117119
@Override
118-
public void propagate(StatesCache cacheHolder, Layer<?, ?> nextLayer, Updater updater, Optimizer optimizer) {
120+
public void propagate(StatesCache cache, Layer<?, ?> nextLayer, Updater updater, Optimizer optimizer) {
119121
throw new UnsupportedOperationException("Not implemented yet.");
120122
}
121123

src/main/java/net/echo/brain4j/layer/impl/convolution/FlattenLayer.java

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,8 @@
66
import net.echo.brain4j.layer.Layer;
77
import net.echo.brain4j.layer.impl.DenseLayer;
88
import net.echo.brain4j.structure.cache.StatesCache;
9+
import net.echo.brain4j.training.optimizers.Optimizer;
10+
import net.echo.brain4j.training.updater.Updater;
911
import net.echo.brain4j.utils.Vector;
1012

1113
public class FlattenLayer extends DenseLayer {
@@ -19,6 +21,17 @@ public Vector forward(StatesCache cache, Layer<?, ?> lastLayer, Vector input) {
1921
return input;
2022
}
2123

24+
@Override
25+
public void propagate(StatesCache cache, Layer<?, ?> previous, Updater updater, Optimizer optimizer) {
26+
super.propagate(cache, previous, updater, optimizer);
27+
28+
for (int i = 0; i < getTotalNeurons(); i++) {
29+
double value = neurons.get(i).getValue(cache);
30+
31+
System.out.println(i + " has " + value);
32+
}
33+
}
34+
2235
public Vector flatten(StatesCache cache, Layer<?, ?> layer, Kernel input) {
2336
Preconditions.checkNotNull(input, "Last convolutional input is null! Missing an input layer?");
2437

src/main/java/net/echo/brain4j/layer/impl/recurrent/RecurrentLayer.java

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -75,8 +75,6 @@ public Vector forward(StatesCache cache, Layer<?, ?> lastLayer, Vector input) {
7575
throw new UnsupportedOperationException("Previous layer must be a dense or recurrent layer!");
7676
}
7777

78-
int prevSize = lastLayer.getNeurons().size();
79-
8078
Vector hiddenState = previousTimestep.get();
8179

8280
for (int i = 0; i < neurons.size(); i++) {
@@ -108,8 +106,8 @@ public Vector forward(StatesCache cache, Layer<?, ?> lastLayer, Vector input) {
108106
}
109107

110108
@Override
111-
public void propagate(StatesCache cacheHolder, Layer<?, ?> previous, Updater updater, Optimizer optimizer) {
112-
super.propagate(cacheHolder, previous, updater, optimizer);
109+
public void propagate(StatesCache cache, Layer<?, ?> previous, Updater updater, Optimizer optimizer) {
110+
super.propagate(cache, previous, updater, optimizer);
113111
}
114112

115113
public List<Vector> getRecurrentWeights() {

src/main/java/net/echo/brain4j/structure/cache/Parameters.java

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,8 @@
22

33
public class Parameters {
44

5-
public static int TOTAL_LAYERS;
6-
public static int TOTAL_SYNAPSES;
7-
public static int TOTAL_NEURONS;
5+
public static int TOTAL_CONV_LAYER = 0;
6+
public static int TOTAL_LAYERS = 0;
7+
public static int TOTAL_SYNAPSES = 0;
8+
public static int TOTAL_NEURONS = 0;
89
}

src/main/java/net/echo/brain4j/structure/cache/StatesCache.java

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,27 @@
11
package net.echo.brain4j.structure.cache;
22

3+
import net.echo.brain4j.convolution.Kernel;
4+
import net.echo.brain4j.layer.impl.convolution.ConvLayer;
35
import net.echo.brain4j.structure.Neuron;
46

57
public class StatesCache {
68

9+
private final Kernel[] featureMaps;
710
private final float[] valuesCache;
811
private final float[] deltasCache;
912

1013
public StatesCache() {
1114
this.valuesCache = new float[Parameters.TOTAL_NEURONS];
1215
this.deltasCache = new float[Parameters.TOTAL_NEURONS];
16+
this.featureMaps = new Kernel[Parameters.TOTAL_CONV_LAYER];
17+
}
18+
19+
public void setFeatureMap(ConvLayer layer, Kernel output) {
20+
featureMaps[layer.getId()] = output;
21+
}
22+
23+
public Kernel getFeatureMap(ConvLayer layer) {
24+
return featureMaps[layer.getId()];
1325
}
1426

1527
public float getValue(Neuron neuron) {

src/test/java/conv/ConvExample.java

Lines changed: 22 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,8 @@
1313
import net.echo.brain4j.model.initialization.WeightInit;
1414
import net.echo.brain4j.training.data.DataRow;
1515
import net.echo.brain4j.training.optimizers.impl.Adam;
16+
import net.echo.brain4j.training.techniques.SmartTrainer;
17+
import net.echo.brain4j.training.techniques.TrainListener;
1618
import net.echo.brain4j.training.updater.impl.StochasticUpdater;
1719
import net.echo.brain4j.utils.DataSet;
1820
import net.echo.brain4j.utils.Vector;
@@ -25,24 +27,22 @@
2527

2628
public class ConvExample {
2729

28-
public static void main(String[] args) {
30+
public static void main(String[] args) throws IOException {
2931
ConvExample example = new ConvExample();
3032
example.start();
3133
}
3234

33-
private void start() {
34-
Model model = getModel();
35+
private void start() throws IOException {
36+
Sequential model = getModel();
3537
DataSet<DataRow> dataSet = getDataSet();
3638

37-
double loss = model.evaluate(dataSet);
38-
System.out.println("Initial loss: " + loss);
39+
model.fit(dataSet);
3940

40-
for (int i = 0; i < 1000; i++) {
41+
double loss = model.evaluate(dataSet);
42+
System.out.println("Loss: " + loss);
43+
/*for (int i = 0; i < 1000; i++) {
4144
model.fit(dataSet);
42-
43-
loss = model.evaluate(dataSet);
44-
System.out.println("Final loss: " + loss + " at " + i);
45-
}
45+
}*/
4646
}
4747

4848
private Sequential getModel() {
@@ -52,14 +52,14 @@ private Sequential getModel() {
5252

5353
// #1 convolutional block
5454
new ConvLayer(32, 3, 3, Activations.RELU),
55-
new PoolingLayer(PoolingType.MAX, 2, 2, 2),
55+
// new PoolingLayer(PoolingType.MAX, 2, 2, 2),
5656

5757
// #2 convolutional block
5858
new ConvLayer(64, 5, 5, Activations.RELU),
59-
new PoolingLayer(PoolingType.MAX, 2, 2, 2),
59+
// new PoolingLayer(PoolingType.MAX, 2, 2, 2),
6060

6161
// Flattens the feature map to a 1D vector
62-
new FlattenLayer(25), // You must find the right size by trial and error
62+
new FlattenLayer(484), // You must find the right size by trial and error
6363

6464
// Classifiers
6565
new DenseLayer(32, Activations.RELU),
@@ -69,25 +69,21 @@ private Sequential getModel() {
6969
return model.compile(WeightInit.HE, LossFunctions.CROSS_ENTROPY, new Adam(0.1), new StochasticUpdater());
7070
}
7171

72-
private DataSet<DataRow> getDataSet() {
72+
private DataSet<DataRow> getDataSet() throws IOException {
7373
DataSet<DataRow> set = new DataSet<>();
7474

75-
try {
76-
List<String> lines = FileUtils.readLines(new File("dataset.csv"), "UTF-8");
75+
List<String> lines = FileUtils.readLines(new File("dataset.csv"), "UTF-8");
7776

78-
for (String line : lines) {
79-
String[] parts = line.split(",");
80-
double[] inputs = Arrays.stream(parts, 1, parts.length).mapToDouble(x -> Double.parseDouble(x) / 255).toArray();
77+
for (String line : lines) {
78+
String[] parts = line.split(",");
79+
double[] inputs = Arrays.stream(parts, 1, parts.length).mapToDouble(x -> Double.parseDouble(x) / 255).toArray();
8180

82-
Vector output = new Vector(10);
81+
Vector output = new Vector(10);
8382

84-
int value = Integer.parseInt(parts[0]);
85-
output.set(value, 1);
83+
int value = Integer.parseInt(parts[0]);
84+
output.set(value, 1);
8685

87-
set.getData().add(new DataRow(Vector.of(inputs), output));
88-
}
89-
} catch (IOException e) {
90-
throw new RuntimeException("Error reading dataset: " + e.getMessage(), e);
86+
set.getData().add(new DataRow(Vector.of(inputs), output));
9187
}
9288

9389
return set;

src/test/java/mnist/MNISTClassifier.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -138,7 +138,7 @@ public static DataSet<DataRow> getData() {
138138
return set;
139139
}
140140

141-
private static class ExampleListener extends TrainListener {
141+
private static class ExampleListener extends TrainListener<DataRow> {
142142

143143
@Override
144144
public void onEvaluated(DataSet<DataRow> dataSet, int epoch, double loss, long took) {

0 commit comments

Comments (0)