@@ -30,8 +30,8 @@ namespace ann /** Artificial Neural Network. */ {
3030 * arma::sp_mat or arma::cube).
3131 */
3232template <
33- typename OutputLayerType = NegativeLogLikelihood,
34- typename InitType = XavierInitialization,
33+ typename OutputLayerType = NegativeLogLikelihood<> ,
34+ typename InitializationRuleType = XavierInitialization,
3535 typename InputDataType = arma::mat,
3636 typename OutputDataType = arma::mat
3737>
@@ -41,21 +41,23 @@ class BERT
4141 BERT ();
4242
4343 /**
44- * Create the TransformerDecoder object using the specified parameters.
44+ * Create the BERT object using the specified parameters.
4545 *
46- * @param vocabSize The size of the vocabulary.
46+ * @param srcVocabSize The size of the vocabulary.
47+ * @param srcSeqLen The source sequence length.
48+ * @param numEncoderLayers The number of Transformer Encoder layers.
4749 * @param dModel The dimensionality of the model.
4850 * @param numHeads The number of attention heads.
49- * @param numLayers The number of Transformer Encoder layers.
5051 * @param dropout The dropout rate.
51- * @param maxSequenceLength The maximum sequence length in the given input.
52+ * @param attentionMask The attention mask used to mask out future tokens.
53+ * @param keyPaddingMask The padding mask used to mask out padding tokens.
5254 */
53- BERT (const size_t vocabSize,
55+ BERT (const size_t srcVocabSize,
56+ const size_t srcSeqLen,
57+ const size_t numEncoderLayers = 12 ,
5458 const size_t dModel = 512 ,
5559 const size_t numHeads = 8 ,
56- const size_t numLayers = 12 ,
5760 const double dropout = 0.1 ,
58- const size_t maxSequenceLength = 5000 ,
5961 const InputDataType& attentionMask = InputDataType(),
6062 const InputDataType& keyPaddingMask = InputDataType());
6163
@@ -75,7 +77,13 @@ class BERT
7577
7678 private:
7779 // ! Locally-stored size of the vocabulary.
78- size_t vocabSize;
80+ size_t srcVocabSize;
81+
82+ // ! Locally-stored source sequence length.
83+ size_t srcSeqLen;
84+
85+ // ! Locally-stored number of Transformer Encoder blocks.
86+ size_t numEncoderLayers;
7987
8088 // ! Locally-stored dimensionality of the model.
8189 size_t dModel;
@@ -86,26 +94,17 @@ class BERT
8694 // ! Locally-stored number of hidden units in FFN.
8795 size_t dimFFN;
8896
89- // ! Locally-stored number of Transformer Encoder blocks.
90- size_t numLayers;
91-
9297 // ! Locally-stored dropout rate.
9398 double dropout;
9499
95- // ! Locally-stored maximum sequence length.
96- size_t maxSequenceLength;
97-
98100 // ! Locally-stored attention mask.
99101 InputDataType attentionMask;
100102
101103 // ! Locally-stored key padding mask.
102104 InputDataType keyPaddingMask;
103105
104- // ! Locally-stored BERT embedding layer.
105- LayerTypes<> embedding;
106-
107106 //! Locally-stored complete BERT encoder network.
108- FFN<OutputLayerType, InitType > bert;
107+ FFN<OutputLayerType, InitializationRuleType > bert;
109108}; // class BERT
110109
111110} // namespace ann
0 commit comments