Skip to content

Commit e166002

Browse files
committed
Added masked multi-head attention and transformer decoder
1 parent fcfb183 commit e166002

File tree

9 files changed

+332
-119
lines changed

9 files changed

+332
-119
lines changed

build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ dependencies {
1616
implementation 'commons-io:commons-io:2.18.0'
1717
implementation 'org.jocl:jocl:2.0.5'
1818
testImplementation 'org.jfree:jfreechart:1.5.3'
19+
implementation 'org.apache.commons:commons-math3:3.6.1'
1920
}
2021

2122
java {
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
package net.echo.brain4j.transformers;
2+
3+
import net.echo.brain4j.activation.Activations;
4+
import net.echo.brain4j.layer.Layer;
5+
import net.echo.brain4j.layer.impl.DenseLayer;
6+
import net.echo.brain4j.layer.impl.LayerNorm;
7+
import net.echo.brain4j.loss.LossFunctions;
8+
import net.echo.brain4j.model.impl.Sequential;
9+
import net.echo.brain4j.model.initialization.WeightInit;
10+
import net.echo.brain4j.structure.cache.StatesCache;
11+
import net.echo.brain4j.training.optimizers.Optimizer;
12+
import net.echo.brain4j.training.updater.Updater;
13+
import net.echo.brain4j.transformers.attention.MultiHeadAttention;
14+
import net.echo.brain4j.transformers.masked.MaskedMultiHeadAttention;
15+
import net.echo.brain4j.utils.Vector;
16+
17+
import java.util.ArrayList;
18+
import java.util.List;
19+
20+
public class TransformerDecoder extends Layer<List<Vector>, List<Vector>> {
21+
22+
private final int heads;
23+
private final int dimension;
24+
private final double temperature;
25+
26+
private final Sequential feedForward;
27+
private final LayerNorm normalizer;
28+
29+
private MaskedMultiHeadAttention maskedAttention;
30+
31+
public TransformerDecoder(int numHeads, int dimension, double temperature) {
32+
super(0, Activations.LINEAR);
33+
this.heads = numHeads;
34+
this.dimension = dimension;
35+
this.temperature = temperature;
36+
37+
this.normalizer = new LayerNorm();
38+
this.feedForward = new Sequential(
39+
new DenseLayer(dimension, Activations.LINEAR),
40+
new DenseLayer(4 * dimension, Activations.GELU),
41+
new DenseLayer(dimension, Activations.LINEAR)
42+
);
43+
}
44+
45+
public int getAttentionSize() {
46+
return maskedAttention.getTotalNeurons();
47+
}
48+
49+
public int getFeedForwardSize() {
50+
return feedForward.getTotalWeights();
51+
}
52+
53+
@Override
54+
public int getTotalParams() {
55+
return getAttentionSize() + getFeedForwardSize();
56+
}
57+
58+
@Override
59+
public int getTotalNeurons() {
60+
return feedForward.getTotalNeurons();
61+
}
62+
63+
@Override
64+
public void compile(WeightInit weightInit, LossFunctions lossFunction, Optimizer optimizer, Updater updater) {
65+
this.maskedAttention = new MaskedMultiHeadAttention(weightInit, heads, dimension, temperature);
66+
this.feedForward.compile(weightInit, lossFunction, optimizer, updater);
67+
}
68+
69+
@Override
70+
public void propagate(StatesCache cache, Layer<?, ?> previous, Updater updater, Optimizer optimizer) {
71+
// feedForward.propagate(cache, this, updater, optimizer);
72+
// maskedAttention.propagate(cache, this, updater, optimizer);
73+
}
74+
75+
@Override
76+
public List<Vector> forward(StatesCache cache, Layer<?, ?> lastLayer, List<Vector> input) {
77+
List<Vector> attentionOutput = maskedAttention.attend(input);
78+
List<Vector> normAttention = new ArrayList<>();
79+
80+
for (Vector token : attentionOutput) {
81+
normAttention.add(normalizer.normalize(token));
82+
}
83+
84+
List<Vector> feedForwardOutput = new ArrayList<>();
85+
86+
for (Vector vector : normAttention) {
87+
feedForwardOutput.add(feedForward.predict(vector));
88+
}
89+
90+
List<Vector> result = new ArrayList<>();
91+
92+
for (int i = 0; i < feedForwardOutput.size(); i++) {
93+
Vector tokenFF = feedForwardOutput.get(i);
94+
95+
tokenFF.add(normAttention.get(i));
96+
result.add(normalizer.normalize(tokenFF));
97+
}
98+
99+
return result;
100+
}
101+
102+
public Sequential getFeedForward() {
103+
return feedForward;
104+
}
105+
106+
public LayerNorm getNormalizer() {
107+
return normalizer;
108+
}
109+
110+
public MultiHeadAttention getMaskedAttention() {
111+
return maskedAttention;
112+
}
113+
}

src/main/java/net/echo/brain4j/transformers/attention/AttentionHead.java

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010

1111
public class AttentionHead {
1212

13-
private final int inputDimension;
14-
private final int headDimension;
15-
private final double temperature;
13+
protected final int inputDimension;
14+
protected final int headDimension;
15+
protected final double temperature;
1616

17-
private final float[][] queryWeights;
18-
private final float[][] keyWeights;
19-
private final float[][] valueWeights;
17+
protected final float[][] queryWeights;
18+
protected final float[][] keyWeights;
19+
protected final float[][] valueWeights;
2020

2121
public AttentionHead(WeightInit weightInit, int inputDimension, int headDimension, double temperature) {
2222
this.inputDimension = inputDimension;
@@ -40,7 +40,46 @@ public int size() {
4040
return total;
4141
}
4242

43-
private void initializeWeights(WeightInit weightInit) {
43+
public List<Vector> attend(List<Vector> inputs) {
44+
int sequenceLength = inputs.size();
45+
46+
List<Vector> queries = new ArrayList<>();
47+
List<Vector> keys = new ArrayList<>();
48+
List<Vector> values = new ArrayList<>();
49+
50+
for (Vector token : inputs) {
51+
queries.add(multiply(token, queryWeights));
52+
keys.add(multiply(token, keyWeights));
53+
values.add(multiply(token, valueWeights));
54+
}
55+
56+
List<Vector> output = new ArrayList<>();
57+
double scale = Math.sqrt(headDimension);
58+
59+
for (int i = 0; i < sequenceLength; i++) {
60+
Vector query = queries.get(i);
61+
List<Double> scoreList = new ArrayList<>();
62+
63+
for (int j = 0; j < sequenceLength; j++) {
64+
double score = query.weightedSum(keys.get(j)) / scale;
65+
scoreList.add(score);
66+
}
67+
68+
Vector attentionWeights = softmax(scoreList);
69+
Vector headOutput = new Vector(headDimension);
70+
71+
for (int j = 0; j < sequenceLength; j++) {
72+
Vector weightedValue = values.get(j).scale(attentionWeights.get(j));
73+
headOutput = headOutput.add(weightedValue);
74+
}
75+
76+
output.add(headOutput);
77+
}
78+
79+
return output;
80+
}
81+
82+
protected void initializeWeights(WeightInit weightInit) {
4483
Random rng = new Random();
4584
WeightInitializer initializer = weightInit.getInitializer();
4685

@@ -55,7 +94,7 @@ private void initializeWeights(WeightInit weightInit) {
5594
}
5695
}
5796

58-
private Vector multiply(Vector vector, float[][] weights) {
97+
protected Vector multiply(Vector vector, float[][] weights) {
5998
Vector result = new Vector(headDimension);
6099

61100
for (int j = 0; j < headDimension; j++) {
@@ -70,7 +109,7 @@ private Vector multiply(Vector vector, float[][] weights) {
70109
return result;
71110
}
72111

73-
private Vector softmax(List<Double> scores) {
112+
protected Vector softmax(List<Double> scores) {
74113
Vector result = new Vector(scores.size());
75114
double maxScore = Double.NEGATIVE_INFINITY;
76115

@@ -95,43 +134,4 @@ private Vector softmax(List<Double> scores) {
95134

96135
return result;
97136
}
98-
99-
public List<Vector> attend(List<Vector> inputs) {
100-
int sequenceLength = inputs.size();
101-
102-
List<Vector> queries = new ArrayList<>();
103-
List<Vector> keys = new ArrayList<>();
104-
List<Vector> values = new ArrayList<>();
105-
106-
for (Vector token : inputs) {
107-
queries.add(multiply(token, queryWeights));
108-
keys.add(multiply(token, keyWeights));
109-
values.add(multiply(token, valueWeights));
110-
}
111-
112-
List<Vector> output = new ArrayList<>();
113-
double scale = Math.sqrt(headDimension);
114-
115-
for (int i = 0; i < sequenceLength; i++) {
116-
Vector query = queries.get(i);
117-
List<Double> scoreList = new ArrayList<>();
118-
119-
for (int j = 0; j < sequenceLength; j++) {
120-
double score = query.weightedSum(keys.get(j)) / scale;
121-
scoreList.add(score);
122-
}
123-
124-
Vector attentionWeights = softmax(scoreList);
125-
Vector headOutput = new Vector(headDimension);
126-
127-
for (int j = 0; j < sequenceLength; j++) {
128-
Vector weightedValue = values.get(j).scale(attentionWeights.get(j));
129-
headOutput = headOutput.add(weightedValue);
130-
}
131-
132-
output.add(headOutput);
133-
}
134-
135-
return output;
136-
}
137137
}
Lines changed: 43 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,25 @@
11
package net.echo.brain4j.transformers.attention;
22

33
import com.google.common.base.Preconditions;
4-
import net.echo.brain4j.activation.Activations;
5-
import net.echo.brain4j.layer.Layer;
64
import net.echo.brain4j.model.initialization.WeightInit;
75
import net.echo.brain4j.utils.Vector;
86

97
import java.util.ArrayList;
108
import java.util.List;
119
import java.util.Random;
1210

13-
public class MultiHeadAttention extends Layer<List<Vector>, List<Vector>> {
11+
public class MultiHeadAttention {
1412

1513
private final List<AttentionHead> heads;
16-
private final WeightInit weightInit;
17-
private final double temperature;
18-
private final int headCount;
19-
private final int modelDimension;
20-
private final int headDimension;
14+
protected final WeightInit weightInit;
15+
protected final double temperature;
16+
protected final int headCount;
17+
protected final int modelDimension;
18+
protected final int headDimension;
2119

22-
private final float[][] outProjectionWeights;
20+
protected final float[][] outProjectionWeights;
2321

2422
public MultiHeadAttention(WeightInit weightInit, int headCount, int modelDimension, double temperature) {
25-
super(0, Activations.LINEAR);
2623
this.weightInit = weightInit;
2724
this.headCount = headCount;
2825
this.modelDimension = modelDimension;
@@ -38,13 +35,36 @@ public MultiHeadAttention(WeightInit weightInit, int headCount, int modelDimensi
3835
initializeOutProjectionWeights();
3936
}
4037

41-
private void initializeHeads() {
42-
for (int i = 0; i < headCount; i++) {
43-
heads.add(new AttentionHead(weightInit, modelDimension, headDimension, temperature));
38+
public List<Vector> attend(List<Vector> inputs) {
39+
List<List<Vector>> headOutputs = new ArrayList<>();
40+
41+
for (AttentionHead head : heads) {
42+
headOutputs.add(head.attend(inputs));
43+
}
44+
45+
return concatenate(headOutputs, inputs);
46+
}
47+
48+
public List<Vector> concatenate(List<List<Vector>> headOutputs, List<Vector> inputs) {
49+
List<Vector> result = new ArrayList<>();
50+
51+
for (int i = 0; i < inputs.size(); i++) {
52+
List<Vector> concatList = new ArrayList<>();
53+
54+
for (List<Vector> headOutput : headOutputs) {
55+
concatList.add(headOutput.get(i));
56+
}
57+
58+
Vector concatenated = concatenateVectors(concatList);
59+
Vector projected = projectVector(concatenated);
60+
61+
projected.add(inputs.get(i));
62+
result.add(projected);
4463
}
64+
65+
return result;
4566
}
4667

47-
@Override
4868
public int getTotalNeurons() {
4969
int total = 0;
5070

@@ -57,7 +77,13 @@ public int getTotalNeurons() {
5777
return total;
5878
}
5979

60-
private void initializeOutProjectionWeights() {
80+
protected void initializeHeads() {
81+
for (int i = 0; i < headCount; i++) {
82+
heads.add(new AttentionHead(weightInit, modelDimension, headDimension, temperature));
83+
}
84+
}
85+
86+
protected void initializeOutProjectionWeights() {
6187
Random rng = new Random();
6288
double bound = weightInit.getInitializer().getBound(headCount * headDimension, modelDimension);
6389

@@ -69,7 +95,7 @@ private void initializeOutProjectionWeights() {
6995
}
7096
}
7197

72-
private Vector projectVector(Vector concatenated) {
98+
protected Vector projectVector(Vector concatenated) {
7399
Vector result = new Vector(modelDimension);
74100

75101
for (int j = 0; j < modelDimension; j++) {
@@ -85,7 +111,7 @@ private Vector projectVector(Vector concatenated) {
85111
return result;
86112
}
87113

88-
private Vector concatenateVectors(List<Vector> vectors) {
114+
protected Vector concatenateVectors(List<Vector> vectors) {
89115
int totalSize = 0;
90116

91117
for (Vector v : vectors) {
@@ -103,31 +129,4 @@ private Vector concatenateVectors(List<Vector> vectors) {
103129

104130
return concatenated;
105131
}
106-
107-
public List<Vector> attend(List<Vector> inputs) {
108-
List<List<Vector>> headOutputs = new ArrayList<>();
109-
110-
for (AttentionHead head : heads) {
111-
headOutputs.add(head.attend(inputs));
112-
}
113-
114-
int seqLen = inputs.size();
115-
List<Vector> result = new ArrayList<>();
116-
117-
for (int i = 0; i < seqLen; i++) {
118-
List<Vector> concatList = new ArrayList<>();
119-
120-
for (List<Vector> headOutput : headOutputs) {
121-
concatList.add(headOutput.get(i));
122-
}
123-
124-
Vector concatenated = concatenateVectors(concatList);
125-
Vector projected = projectVector(concatenated);
126-
127-
projected.add(inputs.get(i));
128-
result.add(projected);
129-
}
130-
131-
return result;
132-
}
133132
}

0 commit comments

Comments
 (0)