Skip to content

Commit ef75886

Browse files
committed
general code refactors
1 parent 1ae38b1 commit ef75886

File tree

23 files changed

+317
-243
lines changed

23 files changed

+317
-243
lines changed

brain4j-core/src/main/java/org/brain4j/core/layer/Layer.java

Lines changed: 25 additions & 7 deletions
Original file line number · Diff line number · Diff line change
@@ -106,18 +106,27 @@ public void deserialize(DataInputStream stream) throws Exception {
106106
}
107107
}
108108

109-
public void compile(WeightInitializer weightInit, LossFunction lossFunction, Optimizer optimizer, Updater updater) {
109+
public void compile(
110+
WeightInitializer weightInit,
111+
Optimizer optimizer,
112+
Updater updater
113+
) {
110114
this.weightInit = weightInit;
111-
this.lossFunction = lossFunction;
112115
this.optimizer = optimizer;
113116
this.updater = updater;
114117
}
115118

116-
public Tensor computeLoss(StatesCache cache, Tensor targets, Tensor outputs, LossFunction lossFunction) {
119+
public Tensor computeLoss(
120+
int index,
121+
StatesCache cache,
122+
Tensor targets,
123+
Tensor outputs,
124+
LossFunction lossFunction
125+
) {
117126
Tensor error = outputs.minus(targets);
118127
Tensor derivatives = activation.getDerivative(outputs);
119128

120-
Tensor input = cache.getInputTensor(this);
129+
Tensor input = cache.getInputTensor(index);
121130
Tensor delta = lossFunction.getDelta(error, derivatives);
122131

123132
Tensor weightsGradient = input.transpose().matmul(delta);
@@ -128,7 +137,11 @@ public Tensor computeLoss(StatesCache cache, Tensor targets, Tensor outputs, Los
128137
return delta;
129138
}
130139

131-
public void connect(Random generator, Layer previous, double bound) {
140+
public void connect(
141+
Random generator,
142+
Layer previous,
143+
double bound
144+
) {
132145
if (previous == null) return;
133146

134147
int input = previous.getTotalNeurons();
@@ -145,9 +158,14 @@ public void connect(Random generator, Layer previous, double bound) {
145158
}
146159
}
147160

148-
public abstract Tensor forward(StatesCache cache, Tensor input, boolean training);
161+
public abstract Tensor forward(
162+
int index,
163+
StatesCache cache,
164+
Tensor input,
165+
boolean training
166+
);
149167

150-
public Tensor backward(StatesCache cache, Layer previous, Tensor delta) {
168+
public Tensor backward(int index, StatesCache cache, Layer previous, Tensor delta) {
151169
throw new UnsupportedOperationException("Not implemented for " + this.getClass().getSimpleName());
152170
}
153171

brain4j-core/src/main/java/org/brain4j/core/layer/impl/BatchNorm.java

Lines changed: 6 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -52,7 +52,12 @@ public boolean canPropagate() {
5252
}
5353

5454
@Override
55-
public Tensor forward(StatesCache cache, Tensor input, boolean training) {
55+
public Tensor forward(
56+
int index,
57+
StatesCache cache,
58+
Tensor input,
59+
boolean training
60+
) {
5661
int batchSize = input.shape()[0];
5762

5863
Tensor transposed = input.transpose(); // [dimension, batch_size]

brain4j-core/src/main/java/org/brain4j/core/layer/impl/DenseLayer.java

Lines changed: 17 additions & 7 deletions
Original file line number · Diff line number · Diff line change
@@ -38,7 +38,12 @@ public String getLayerName() {
3838
}
3939

4040
@Override
41-
public Tensor forward(StatesCache cache, Tensor input, boolean training) {
41+
public Tensor forward(
42+
int index,
43+
StatesCache cache,
44+
Tensor input,
45+
boolean training
46+
) {
4247
Tensor Z = input.matmul(weights); // [batch_size, n_out]
4348

4449
int batchSize = Z.shape()[0];
@@ -54,7 +59,7 @@ public Tensor forward(StatesCache cache, Tensor input, boolean training) {
5459
}
5560

5661
if (nextLayer instanceof LayerNorm layerNorm) {
57-
Z = layerNorm.forward(cache, Z, training);
62+
Z = layerNorm.forward(index, cache, Z, training);
5863
}
5964

6065
Tensor activated = activation.activate(Z);
@@ -64,13 +69,18 @@ public Tensor forward(StatesCache cache, Tensor input, boolean training) {
6469
}
6570

6671
@Override
67-
public Tensor backward(StatesCache cache, Layer previous, Tensor delta) {
68-
Tensor input = cache.getInputTensor(this);
69-
Tensor output = cache.getOutputTensor(this);
72+
public Tensor backward(
73+
int index,
74+
StatesCache cache,
75+
Layer previous,
76+
Tensor delta
77+
) {
78+
Tensor input = cache.getInputTensor(index);
79+
Tensor output = cache.getOutputTensor(index);
7080
Tensor derivative = activation.getDerivative(output); // [batch_size, n_out]
7181

72-
Tensor weightsNext = previous.getWeights(); // [n_out, n_out_next]
73-
Tensor deltaProjected = delta.matmul(weightsNext.transpose()); // [batch_size x n_out]
82+
Tensor weightsNext = previous.getWeights(); // [n_out, n_out_next]
83+
Tensor deltaProjected = delta.matmul(weightsNext.transpose()); // [batch_size x n_out]
7484

7585
Tensor deltaThisLayer = deltaProjected.mul(derivative); // [batch_size x n_out]
7686

brain4j-core/src/main/java/org/brain4j/core/layer/impl/DropoutLayer.java

Lines changed: 7 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -67,12 +67,18 @@ public boolean canPropagate() {
6767
* during training. Meanwhile it will scale the input tensor by {@code 1 - dropout} during inferencing.
6868
*
6969
* @param nextLayer
70+
* @param index
7071
* @param input The input tensor.
7172
* @param training If it's called during training.
7273
* @return The resulting tensor.
7374
*/
7475
@Override
75-
public Tensor forward(StatesCache cache, Tensor input, boolean training) {
76+
public Tensor forward(
77+
int index,
78+
StatesCache cache,
79+
Tensor input,
80+
boolean training
81+
) {
7682
if (training) {
7783
return scale(input);
7884
}

brain4j-core/src/main/java/org/brain4j/core/layer/impl/LayerNorm.java

Lines changed: 6 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -57,7 +57,12 @@ public boolean canPropagate() {
5757
}
5858

5959
@Override
60-
public Tensor forward(StatesCache cache, Tensor input, boolean training) {
60+
public Tensor forward(
61+
int index,
62+
StatesCache cache,
63+
Tensor input,
64+
boolean training
65+
) {
6166
int batchSize = input.shape()[0];
6267
Tensor result = input.clone();
6368

brain4j-core/src/main/java/org/brain4j/core/layer/impl/conv/ConvLayer.java

Lines changed: 8 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -20,7 +20,8 @@ public class ConvLayer extends Layer {
2020
private int padding;
2121
private int stride;
2222

23-
public ConvLayer(Activations activation, int filters, int filtersWidth, int filtersHeight) {
23+
public ConvLayer(Activations activation, int filters, int filtersWidth, int filtersHeight
24+
) {
2425
this(activation.getFunction(), filters, filtersWidth, filtersHeight, 1, 1);
2526
}
2627

@@ -36,7 +37,11 @@ public ConvLayer(Activation activation, int filters, int filtersWidth, int filte
3637
}
3738

3839
@Override
39-
public void connect(Random generator, Layer previous, double bound) {
40+
public void connect(
41+
Random generator,
42+
Layer previous,
43+
double bound
44+
) {
4045
if (previous instanceof ConvLayer convLayer) {
4146
this.channels = convLayer.getChannels();
4247
} else if (previous instanceof InputLayer inputLayer) {
@@ -79,7 +84,7 @@ public void deserialize(DataInputStream stream) throws Exception {
7984
}
8085

8186
@Override
82-
public Tensor forward(StatesCache cache, Tensor input, boolean training) {
87+
public Tensor forward(int index, StatesCache cache, Tensor input, boolean training) {
8388
// [batch_size, channels, height, width]
8489
return input.convolve(weights)
8590
.map(x -> activation.activate(x));

brain4j-core/src/main/java/org/brain4j/core/layer/impl/conv/FlattenLayer.java

Lines changed: 1 addition & 6 deletions
Original file line number · Diff line number · Diff line change
@@ -7,12 +7,7 @@
77
public class FlattenLayer extends Layer {
88

99
@Override
10-
public int getTotalNeurons() {
11-
return super.getTotalNeurons();
12-
}
13-
14-
@Override
15-
public Tensor forward(StatesCache cache, Tensor input, boolean training) {
10+
public Tensor forward(int index, StatesCache cache, Tensor input, boolean training) {
1611
return input.reshape(1, input.elements());
1712
}
1813
}

brain4j-core/src/main/java/org/brain4j/core/layer/impl/conv/InputLayer.java

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -17,7 +17,7 @@ public InputLayer(int width, int height, int channels) {
1717
}
1818

1919
@Override
20-
public Tensor forward(StatesCache cache, Tensor input, boolean training) {
20+
public Tensor forward(int index, StatesCache cache, Tensor input, boolean training) {
2121
return input;
2222
}
2323

brain4j-core/src/main/java/org/brain4j/core/layer/impl/conv/pooling/AveragePooling.java

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -26,7 +26,7 @@ public boolean canConnect() {
2626
}
2727

2828
@Override
29-
public Tensor forward(StatesCache cache, Tensor input, boolean training) {
29+
public Tensor forward(int index, StatesCache cache, Tensor input, boolean training) {
3030
return null;
3131
}
3232

brain4j-core/src/main/java/org/brain4j/core/layer/impl/conv/pooling/MaxPooling.java

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -26,7 +26,7 @@ public boolean canConnect() {
2626
}
2727

2828
@Override
29-
public Tensor forward(StatesCache cache, Tensor input, boolean training) {
29+
public Tensor forward(int index, StatesCache cache, Tensor input, boolean training) {
3030
return null;
3131
}
3232

0 commit comments

Comments (0)