Skip to content

Commit 81219d8

Browse files
some fixes
1 parent ce6cc28 commit 81219d8

File tree

3 files changed

+45
-41
lines changed

3 files changed

+45
-41
lines changed

models/bert/bert.hpp

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ namespace ann /** Artificial Neural Network. */ {
3030
* arma::sp_mat or arma::cube).
3131
*/
3232
template <
33-
typename OutputLayerType = NegativeLogLikelihood,
34-
typename InitType = XavierInitialization,
33+
typename OutputLayerType = NegativeLogLikelihood<>,
34+
typename InitializationRuleType = XavierInitialization,
3535
typename InputDataType = arma::mat,
3636
typename OutputDataType = arma::mat
3737
>
@@ -41,21 +41,23 @@ class BERT
4141
BERT();
4242

4343
/**
44-
* Create the TransformerDecoder object using the specified parameters.
44+
* Create the BERT object using the specified parameters.
4545
*
46-
* @param vocabSize The size of the vocabulary.
46+
* @param srcVocabSize The size of the vocabulary.
47+
* @param srcSeqLen The source sequence length.
48+
* @param numEncoderLayers The number of Transformer Encoder layers.
4749
* @param dModel The dimensionality of the model.
4850
* @param numHeads The number of attention heads.
49-
* @param numLayers The number of Transformer Encoder layers.
5051
* @param dropout The dropout rate.
51-
* @param maxSequenceLength The maximum sequence length in the given input.
52+
* @param attentionMask The attention mask used to black-out future sequences.
53+
* @param keyPaddingMask Blacks out specific tokens.
5254
*/
53-
BERT(const size_t vocabSize,
55+
BERT(const size_t srcVocabSize,
56+
const size_t srcSeqLen,
57+
const size_t numEncoderLayers = 12,
5458
const size_t dModel = 512,
5559
const size_t numHeads = 8,
56-
const size_t numLayers = 12,
5760
const double dropout = 0.1,
58-
const size_t maxSequenceLength = 5000,
5961
const InputDataType& attentionMask = InputDataType(),
6062
const InputDataType& keyPaddingMask = InputDataType());
6163

@@ -75,7 +77,13 @@ class BERT
7577

7678
private:
7779
//! Locally-stored size of the vocabulary.
78-
size_t vocabSize;
80+
size_t srcVocabSize;
81+
82+
//! Locally-stored source sequence length.
83+
size_t srcSeqLen;
84+
85+
//! Locally-stored number of Transformer Encoder blocks.
86+
size_t numEncoderLayers;
7987

8088
//! Locally-stored dimensionality of the model.
8189
size_t dModel;
@@ -86,26 +94,17 @@ class BERT
8694
//! Locally-stored number of hidden units in FFN.
8795
size_t dimFFN;
8896

89-
//! Locally-stored number of Transformer Encoder blocks.
90-
size_t numLayers;
91-
9297
//! Locally-stored dropout rate.
9398
double dropout;
9499

95-
//! Locally-stored maximum sequence length.
96-
size_t maxSequenceLength;
97-
98100
//! Locally-stored attention mask.
99101
InputDataType attentionMask;
100102

101103
//! Locally-stored key padding mask.
102104
InputDataType keyPaddingMask;
103105

104-
//! Locally-stored BERT embedding layer.
105-
LayerTypes<> embedding;
106-
107106
//! Locally-stored complete encoder network (BERT is encoder-only).
108-
FFN<OutputLayerType, InitType> bert;
107+
FFN<OutputLayerType, InitializationRuleType> bert;
109108
}; // class BERT
110109

111110
} // namespace ann

models/bert/bert_impl.hpp

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,52 +22,60 @@ namespace ann /** Artificial Neural Network. */ {
2222
template<typename OutputLayerType, typename InitType, typename InputDataType,
2323
typename OutputDataType>
2424
BERT<OutputLayerType, InitType, InputDataType, OutputDataType>::BERT() :
25-
vocabSize(0),
25+
srcVocabSize(0),
26+
srcSeqLen(0),
27+
numEncoderLayers(0),
2628
dModel(0),
2729
numHeads(0),
2830
dimFFN(4 * dModel),
29-
numLayers(0),
30-
dropout(0),
31-
maxSequenceLength(5000),
31+
dropout(0.0)
3232
{
3333
// Nothing to do here.
3434
}
3535

3636
template<typename OutputLayerType, typename InitType, typename InputDataType,
3737
typename OutputDataType>
3838
BERT<OutputLayerType, InitType, InputDataType, OutputDataType>::BERT(
39-
const size_t vocabSize,
39+
const size_t srcVocabSize,
40+
const size_t srcSeqLen,
41+
const size_t numEncoderLayers,
4042
const size_t dModel,
4143
const size_t numHeads,
42-
const size_t numLayers,
4344
const double dropout,
44-
const size_t maxSequenceLength,
4545
const InputDataType& attentionMask,
4646
const InputDataType& keyPaddingMask) :
47-
vocabSize(vocabSize)
47+
srcVocabSize(srcVocabSize),
48+
srcSeqLen(srcSeqLen),
49+
numEncoderLayers(numEncoderLayers),
4850
dModel(dModel),
4951
numHeads(numHeads),
5052
dimFFN(4 * dModel),
51-
numLayers(numLayers),
5253
dropout(dropout),
53-
maxSequenceLength(maxSequenceLength),
5454
attentionMask(attentionMask),
5555
keyPaddingMask(keyPaddingMask)
5656
{
57-
embedding = new AddMerge<>();
58-
embedding.Add<Lookup<>>(vocabSize, dModel);
59-
embedding.Add<Lookup<>>(3, dModel);
57+
AddMerge<>* embedding = new AddMerge<>();
58+
embedding->Add<Lookup<>>(srcVocabSize, dModel);
59+
embedding->Add<Lookup<>>(3, dModel);
6060

6161
bert.Add(embedding);
62-
bert.Add<PositionalEncoding<>>(dModel, maxSequenceLength);
62+
bert.Add<PositionalEncoding<>>(dModel, srcSeqLen);
6363
bert.Add<Dropout<>>(dropout);
6464

6565
for (size_t i = 0; i < numEncoderLayers; ++i)
6666
{
67-
TransformerEncoder<> enc(dModel, numHeads, dimFFN, dropout);
68-
enc.AttentionMask() = attentionMask;
69-
enc.KeyPaddingMask() = keyPaddingMask;
70-
bert.Add(enc);
67+
mlpack::ann::TransformerEncoder<> encoder(
68+
numEncoderLayers,
69+
srcSeqLen,
70+
dModel,
71+
numHeads,
72+
dimFFN,
73+
dropout,
74+
attentionMask,
75+
keyPaddingMask
76+
);
77+
78+
bert.Add(encoder.Model());
7179
}
7280
}
7381

models/bert/bert_tokenizer.hpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,4 @@ class BertTokenizer
127127
} // namespace ann
128128
} // namespace mlpack
129129

130-
// Include implementation.
131-
#include "bert_tokenizer_impl.hpp"
132-
133130
#endif

0 commit comments

Comments
 (0)