Skip to content

Commit 86e685a

Browse files
committed
More work on the pooling layer
1 parent a49fb48 commit 86e685a

File tree

7 files changed

+187
-45
lines changed

7 files changed

+187
-45
lines changed

src/main/java/net/echo/brain4j/convolution/pooling/PoolingFunction.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@
66
public interface PoolingFunction {
77

88
double apply(PoolingLayer layer, Kernel input, int i, int j);
9+
10+
void unpool(PoolingLayer layer, int outX, int outY, Kernel deltaPooling, Kernel deltaUnpooled, Kernel input);
911
}

src/main/java/net/echo/brain4j/convolution/pooling/impl/AveragePooling.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,25 @@ public double apply(PoolingLayer layer, Kernel input, int i, int j) {
3131

3232
return sum / count;
3333
}
34+
35+
@Override
36+
public void unpool(PoolingLayer layer, int outX, int outY, Kernel deltaPooling, Kernel deltaUnpooled, Kernel input) {
37+
double deltaVal = deltaPooling.getValue(outX, outY);
38+
39+
int startX = outX * layer.getStride();
40+
int startY = outY * layer.getStride();
41+
42+
int endX = Math.min(startX + layer.getKernelWidth(), input.getWidth());
43+
int endY = Math.min(startY + layer.getKernelHeight(), input.getHeight());
44+
45+
int poolArea = (endX - startX) * (endY - startY);
46+
double distributedDelta = deltaVal / poolArea;
47+
48+
for (int y = startY; y < endY; y++) {
49+
for (int x = startX; x < endX; x++) {
50+
double current = deltaUnpooled.getValue(x, y);
51+
deltaUnpooled.setValue(x, y, current + distributedDelta);
52+
}
53+
}
54+
}
3455
}

src/main/java/net/echo/brain4j/convolution/pooling/impl/MaxPooling.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,33 @@ public double apply(PoolingLayer layer, Kernel input, int i, int j) {
3030

3131
return pooledValue;
3232
}
33+
34+
@Override
35+
public void unpool(PoolingLayer layer, int outX, int outY, Kernel deltaPooling, Kernel deltaUnpooled, Kernel input) {
36+
double deltaVal = deltaPooling.getValue(outX, outY);
37+
double maxVal = Double.NEGATIVE_INFINITY;
38+
39+
int startX = outX * layer.getStride();
40+
int startY = outY * layer.getStride();
41+
42+
int endX = Math.min(startX + layer.getKernelWidth(), input.getWidth());
43+
int endY = Math.min(startY + layer.getKernelHeight(), input.getHeight());
44+
45+
int maxX = startX, maxY = startY;
46+
47+
for (int y = startY; y < endY; y++) {
48+
for (int x = startX; x < endX; x++) {
49+
double val = input.getValue(x, y);
50+
51+
if (val > maxVal) {
52+
maxVal = val;
53+
maxX = x;
54+
maxY = y;
55+
}
56+
}
57+
}
58+
59+
double current = deltaUnpooled.getValue(maxX, maxY);
60+
deltaUnpooled.setValue(maxX, maxY, current + deltaVal);
61+
}
3362
}

src/main/java/net/echo/brain4j/layer/impl/convolution/ConvLayer.java

Lines changed: 67 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,11 @@
44
import net.echo.brain4j.activation.Activations;
55
import net.echo.brain4j.convolution.Kernel;
66
import net.echo.brain4j.layer.Layer;
7-
import net.echo.brain4j.layer.impl.DenseLayer;
87
import net.echo.brain4j.structure.Neuron;
98
import net.echo.brain4j.structure.cache.Parameters;
109
import net.echo.brain4j.structure.cache.StatesCache;
1110
import net.echo.brain4j.training.optimizers.Optimizer;
1211
import net.echo.brain4j.training.updater.Updater;
13-
import org.checkerframework.checker.units.qual.K;
1412

1513
import java.util.ArrayList;
1614
import java.util.List;
@@ -130,50 +128,87 @@ public Kernel forward(StatesCache cache, Layer<?, ?> lastLayer, Kernel input) {
130128
public void propagate(StatesCache cache, Layer<?, ?> nextLayer, Updater updater, Optimizer optimizer) {
131129
Kernel featureMap = cache.getFeatureMap(this);
132130

133-
if (nextLayer instanceof FlattenLayer) {
134-
List<Neuron> neurons = nextLayer.getNeurons();
135-
Kernel deltaKernel = new Kernel(featureMap.getWidth(), featureMap.getHeight());
131+
switch (nextLayer) {
132+
case FlattenLayer flattenLayer -> {
133+
List<Neuron> neurons = nextLayer.getNeurons();
134+
Kernel deltaKernel = new Kernel(featureMap.getWidth(), featureMap.getHeight());
136135

137-
for (int h = 0; h < featureMap.getHeight(); h++) {
138-
for (int w = 0; w < featureMap.getWidth(); w++) {
139-
int index = h * featureMap.getWidth() + w;
140-
double deltaNeuron = neurons.get(index).getDelta(cache);
136+
for (int h = 0; h < featureMap.getHeight(); h++) {
137+
for (int w = 0; w < featureMap.getWidth(); w++) {
138+
int index = h * featureMap.getWidth() + w;
139+
double deltaNeuron = neurons.get(index).getDelta(cache);
141140

142-
double derivative = activation.getFunction().getDerivative(featureMap.getValue(w, h));
143-
double localDelta = deltaNeuron * derivative;
141+
double derivative = activation.getFunction().getDerivative(featureMap.getValue(w, h));
142+
double localDelta = deltaNeuron * derivative;
144143

145-
deltaKernel.setValue(w, h, localDelta);
144+
deltaKernel.setValue(w, h, localDelta);
145+
}
146146
}
147+
148+
updateParameters(cache, optimizer, deltaKernel);
147149
}
150+
case ConvLayer nextConvLayer -> {
151+
Kernel deltaNext = cache.getDelta(nextConvLayer);
152+
Kernel deltaCurrent = new Kernel(featureMap.getWidth(), featureMap.getHeight());
148153

149-
updateParameters(cache, optimizer, deltaKernel);
150-
} else if (nextLayer instanceof ConvLayer nextConvLayer) {
151-
Kernel deltaNext = cache.getDelta(nextConvLayer);
152-
Kernel deltaCurrent = new Kernel(featureMap.getWidth(), featureMap.getHeight());
154+
for (Kernel nextKernel : nextConvLayer.getKernels()) {
155+
Kernel rotatedKernel = nextKernel.rotate180();
156+
Kernel contribution = deltaNext.convolute(rotatedKernel, 1);
153157

154-
for (Kernel nextKernel : nextConvLayer.getKernels()) {
155-
Kernel rotatedKernel = nextKernel.rotate180();
156-
Kernel contribution = deltaNext.convolute(rotatedKernel, 1);
158+
if (contribution.getWidth() != deltaCurrent.getWidth() || contribution.getHeight() != deltaCurrent.getHeight()) {
159+
contribution = cropTo(contribution, deltaCurrent.getWidth(), deltaCurrent.getHeight());
160+
}
157161

158-
if (contribution.getWidth() != deltaCurrent.getWidth() || contribution.getHeight() != deltaCurrent.getHeight()) {
159-
contribution = cropTo(contribution, deltaCurrent.getWidth(), deltaCurrent.getHeight());
162+
deltaCurrent.add(contribution);
160163
}
161164

162-
deltaCurrent.add(contribution);
163-
}
164-
165-
for (int h = 0; h < deltaCurrent.getHeight(); h++) {
166-
for (int w = 0; w < deltaCurrent.getWidth(); w++) {
167-
double derivative = activation.getFunction().getDerivative(featureMap.getValue(w, h));
168-
double updatedDelta = clipGradient(deltaCurrent.getValue(w, h) * derivative);
165+
for (int h = 0; h < deltaCurrent.getHeight(); h++) {
166+
for (int w = 0; w < deltaCurrent.getWidth(); w++) {
167+
double derivative = activation.getFunction().getDerivative(featureMap.getValue(w, h));
168+
double updatedDelta = clipGradient(deltaCurrent.getValue(w, h) * derivative);
169169

170-
deltaCurrent.setValue(w, h, updatedDelta);
170+
deltaCurrent.setValue(w, h, updatedDelta);
171+
}
171172
}
173+
174+
updateParameters(cache, optimizer, deltaCurrent);
172175
}
176+
case PoolingLayer poolingLayer -> {
177+
Kernel deltaPooling = cache.getDelta(poolingLayer);
178+
Kernel deltaUnpooled = new Kernel(featureMap.getWidth(), featureMap.getHeight());
179+
180+
int poolWidth = poolingLayer.getKernelWidth();
181+
int poolHeight = poolingLayer.getKernelHeight();
182+
int poolStride = poolingLayer.getStride();
183+
184+
for (int ph = 0; ph < deltaPooling.getHeight(); ph++) {
185+
for (int pw = 0; pw < deltaPooling.getWidth(); pw++) {
186+
double deltaVal = deltaPooling.getValue(pw, ph);
187+
188+
int startX = pw * poolStride;
189+
int startY = ph * poolStride;
190+
191+
for (int y = startY; y < startY + poolHeight && y < featureMap.getHeight(); y++) {
192+
for (int x = startX; x < startX + poolWidth && x < featureMap.getWidth(); x++) {
193+
double current = deltaUnpooled.getValue(x, y);
194+
current += deltaVal / (poolWidth * poolHeight);
195+
deltaUnpooled.setValue(x, y, current);
196+
}
197+
}
198+
}
199+
}
200+
201+
for (int h = 0; h < deltaUnpooled.getHeight(); h++) {
202+
for (int w = 0; w < deltaUnpooled.getWidth(); w++) {
203+
double derivative = activation.getFunction().getDerivative(featureMap.getValue(w, h));
204+
double updatedDelta = clipGradient(deltaUnpooled.getValue(w, h) * derivative);
205+
deltaUnpooled.setValue(w, h, updatedDelta);
206+
}
207+
}
173208

174-
updateParameters(cache, optimizer, deltaCurrent);
175-
} else {
176-
throw new UnsupportedOperationException("Propagation not support for " + nextLayer.getClass().getSimpleName());
209+
updateParameters(cache, optimizer, deltaUnpooled);
210+
}
211+
default -> throw new UnsupportedOperationException("Propagation not support for " + nextLayer.getClass().getSimpleName());
177212
}
178213
}
179214

src/main/java/net/echo/brain4j/layer/impl/convolution/PoolingLayer.java

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
import net.echo.brain4j.convolution.Kernel;
66
import net.echo.brain4j.convolution.pooling.PoolingType;
77
import net.echo.brain4j.layer.Layer;
8+
import net.echo.brain4j.structure.Neuron;
9+
import net.echo.brain4j.structure.cache.Parameters;
810
import net.echo.brain4j.structure.cache.StatesCache;
11+
import net.echo.brain4j.training.optimizers.Optimizer;
12+
import net.echo.brain4j.training.updater.Updater;
913

1014
public class PoolingLayer extends Layer<Kernel, Kernel> {
1115

@@ -25,6 +29,7 @@ public PoolingLayer(PoolingType poolingType, int kernelWidth, int kernelHeight,
2529

2630
public PoolingLayer(PoolingType poolingType, int kernelWidth, int kernelHeight, int stride, int padding) {
2731
super(kernelWidth * kernelHeight, Activations.LINEAR);
32+
this.id = Parameters.TOTAL_CONV_LAYER++;
2833
this.poolingType = poolingType;
2934
this.kernelHeight = kernelHeight;
3035
this.kernelWidth = kernelWidth;
@@ -56,8 +61,56 @@ public Kernel forward(StatesCache cache, Layer<?, ?> lastLayer, Kernel input) {
5661
}
5762
}
5863

64+
cache.setFeatureMap(this, output);
65+
cache.setInput(this, input);
66+
5967
return output;
6068
}
69+
@Override
70+
public void propagate(StatesCache cache, Layer<?, ?> nextLayer, Updater updater, Optimizer optimizer) {
71+
System.out.println("Layer id: " + id);
72+
Kernel output = cache.getFeatureMap(this);
73+
Kernel input = cache.getInput(this);
74+
75+
Kernel deltaPooling = new Kernel(output.getWidth(), output.getHeight());
76+
77+
if (nextLayer instanceof ConvLayer convLayer) {
78+
System.out.println("Getting pooling from conv");
79+
deltaPooling = cache.getDelta(convLayer);
80+
} else if (nextLayer instanceof FlattenLayer flattenLayer) {
81+
int outW = output.getWidth();
82+
int outH = output.getHeight();
83+
84+
for (int h = 0; h < outH; h++) {
85+
for (int w = 0; w < outW; w++) {
86+
int index = h * outW + w;
87+
88+
Neuron neuron = flattenLayer.getNeuronAt(index);
89+
double neuronDelta = neuron.getDelta(cache);
90+
91+
deltaPooling.setValue(w, h, neuronDelta);
92+
}
93+
}
94+
} else {
95+
throw new UnsupportedOperationException("Unsupported layer after pooling layer!");
96+
}
97+
98+
Kernel deltaUnpooled = new Kernel(input.getWidth(), input.getHeight());
99+
100+
System.out.println("OutX: " + output.getWidth());
101+
System.out.println("OutY: " + output.getHeight());
102+
103+
System.out.println("DX: " + deltaPooling.getWidth());
104+
System.out.println("DY: " + deltaPooling.getHeight());
105+
for (int outY = 0; outY < output.getHeight(); outY++) {
106+
for (int outX = 0; outX < output.getWidth(); outX++) {
107+
poolingType.getFunction().unpool(this, outX, outY, deltaPooling, deltaUnpooled, input);
108+
}
109+
}
110+
111+
cache.setDelta(this, deltaUnpooled);
112+
}
113+
61114

62115
public PoolingType getPoolingType() {
63116
return poolingType;

src/main/java/net/echo/brain4j/structure/cache/StatesCache.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package net.echo.brain4j.structure.cache;
22

33
import net.echo.brain4j.convolution.Kernel;
4-
import net.echo.brain4j.layer.impl.convolution.ConvLayer;
4+
import net.echo.brain4j.layer.Layer;
55
import net.echo.brain4j.structure.Neuron;
66

77
public class StatesCache {
@@ -15,32 +15,33 @@ public class StatesCache {
1515
public StatesCache() {
    // Flat per-neuron caches sized over the whole network.
    this.valuesCache = new float[Parameters.TOTAL_NEURONS];
    this.deltasCache = new float[Parameters.TOTAL_NEURONS];

    // Per-layer kernel caches, indexed by layer id; TOTAL_CONV_LAYER counts
    // every kernel-based layer that claims an id from this counter.
    this.inputMap = new Kernel[Parameters.TOTAL_CONV_LAYER];
    this.featureMaps = new Kernel[Parameters.TOTAL_CONV_LAYER];
    this.deltaMap = new Kernel[Parameters.TOTAL_CONV_LAYER];
}
2223

23-
/** Caches the raw input kernel fed to the given kernel-based layer. */
public void setInput(Layer<Kernel, Kernel> layer, Kernel input) {
    inputMap[layer.getId()] = input;
}
2627

27-
/** Returns the input kernel previously cached for the given layer, if any. */
public Kernel getInput(Layer<Kernel, Kernel> layer) {
    return inputMap[layer.getId()];
}
3031

31-
/** Caches the forward-pass output (feature map) of the given layer. */
public void setFeatureMap(Layer<Kernel, Kernel> layer, Kernel output) {
    featureMaps[layer.getId()] = output;
}
3435

35-
/** Returns the feature map previously cached for the given layer, if any. */
public Kernel getFeatureMap(Layer<Kernel, Kernel> layer) {
    return featureMaps[layer.getId()];
}
3839

39-
/** Returns the backpropagation delta previously stored for the given layer. */
public Kernel getDelta(Layer<Kernel, Kernel> layer) {
    return deltaMap[layer.getId()];
}
4243

43-
/** Stores the backpropagation delta for the given layer. */
public void setDelta(Layer<Kernel, Kernel> layer, Kernel delta) {
    deltaMap[layer.getId()] = delta;
}
4647

src/test/java/conv/ConvExample.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ private void start() throws IOException {
4242
System.out.println(model.getStats());
4343
model.fit(dataSet);
4444

45-
for (int i = 0; i < 100; i++) {
45+
for (int i = 0; i < 1; i++) {
4646
long start = System.nanoTime();
4747
model.fit(dataSet);
4848
double took = (System.nanoTime() - start) / 1e6;
@@ -79,15 +79,16 @@ private Sequential getModel() {
7979
new InputLayer(28, 28),
8080

8181
// #1 convolutional block
82-
new ConvLayer(32, 3, 3, Activations.MISH),
83-
// new PoolingLayer(PoolingType.MAX, 2, 2, 2),
82+
new ConvLayer(20, 3, 3, Activations.MISH),
83+
new PoolingLayer(PoolingType.MAX, 2, 2, 2),
8484

8585
// #2 convolutional block
8686
new ConvLayer(16, 5, 5, Activations.MISH),
87-
// new PoolingLayer(PoolingType.MAX, 2, 2, 2),
87+
88+
new PoolingLayer(PoolingType.MAX, 2, 2, 2),
8889

8990
// Flattens the feature map to a 1D vector
90-
new FlattenLayer(484), // You must find the right size by trial and error
91+
new FlattenLayer(25), // You must find the right size by trial and error
9192

9293
// Classifiers
9394
new DenseLayer(32, Activations.MISH),
@@ -102,7 +103,7 @@ private DataSet<DataRow> getDataSet() throws IOException {
102103

103104
List<String> lines = FileUtils.readLines(new File("dataset.csv"), "UTF-8");
104105

105-
int max = 1500, i = 0;
106+
int max = 1, i = 0;
106107

107108
for (String line : lines) {
108109
i++;

0 commit comments

Comments
 (0)