Skip to content

Commit 39da2c8

Browse files
add function to load vocabulary
1 parent 278dfc7 commit 39da2c8

File tree

6 files changed

+110
-29
lines changed

6 files changed

+110
-29
lines changed

models/bert/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ set(SOURCES
88
bert.hpp
99
bert_impl.hpp
1010
bert_tokenizer.hpp
11+
bert_tokenizer_impl.hpp
1112
)
1213

1314
foreach(file ${SOURCES})

models/bert/bert.hpp

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,12 @@ namespace ann /** Artificial Neural Network. */ {
2323

2424
/**
2525
* @tparam OutputLayerType Type of the last layer to be added to BERT model.
26-
* @tparam InitType Initilization Rule to be used to initialize parameters.
27-
* @tparam InputDataType Type of the input data (arma::colvec, arma::mat,
28-
* arma::sp_mat or arma::cube).
29-
* @tparam OutputDataType Type of the output data (arma::colvec, arma::mat,
30-
* arma::sp_mat or arma::cube).
26+
* @tparam InitializationRuleType Initilization Rule to be used to initialize
27+
* parameters.
3128
*/
3229
template <
3330
typename OutputLayerType = NegativeLogLikelihood<>,
34-
typename InitializationRuleType = XavierInitialization,
35-
typename InputDataType = arma::mat,
36-
typename OutputDataType = arma::mat
31+
typename InitializationRuleType = XavierInitialization
3732
>
3833
class BERT
3934
{
@@ -58,8 +53,8 @@ class BERT
5853
const size_t dModel = 512,
5954
const size_t numHeads = 8,
6055
const double dropout = 0.1,
61-
const InputDataType& attentionMask = InputDataType(),
62-
const InputDataType& keyPaddingMask = InputDataType());
56+
const arma::mat& attentionMask = arma::mat(),
57+
const arma::mat& keyPaddingMask = arma::mat());
6358

6459
/**
6560
* Load the network from a local directory.
@@ -98,10 +93,10 @@ class BERT
9893
double dropout;
9994

10095
//! Locally-stored attention mask.
101-
InputDataType attentionMask;
96+
arma::mat attentionMask;
10297

10398
//! Locally-stored key padding mask.
104-
InputDataType keyPaddingMask;
99+
arma::mat keyPaddingMask;
105100

106101
//! Locally-stored complete decoder network.
107102
FFN<OutputLayerType, InitializationRuleType> bert;

models/bert/bert_impl.hpp

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,8 @@
1919
namespace mlpack {
2020
namespace ann /** Artificial Neural Network. */ {
2121

22-
template<typename OutputLayerType, typename InitType, typename InputDataType,
23-
typename OutputDataType>
24-
BERT<OutputLayerType, InitType, InputDataType, OutputDataType>::BERT() :
22+
template<typename OutputLayerType, typename InitializationRuleType>
23+
BERT<OutputLayerType, InitializationRuleType>::BERT() :
2524
srcVocabSize(0),
2625
srcSeqLen(0),
2726
numEncoderLayers(0),
@@ -33,17 +32,16 @@ BERT<OutputLayerType, InitType, InputDataType, OutputDataType>::BERT() :
3332
// Nothing to do here.
3433
}
3534

36-
template<typename OutputLayerType, typename InitType, typename InputDataType,
37-
typename OutputDataType>
38-
BERT<OutputLayerType, InitType, InputDataType, OutputDataType>::BERT(
35+
template<typename OutputLayerType, typename InitializationRuleType>
36+
BERT<OutputLayerType, InitializationRuleType>::BERT(
3937
const size_t srcVocabSize,
4038
const size_t srcSeqLen,
4139
const size_t numEncoderLayers,
4240
const size_t dModel,
4341
const size_t numHeads,
4442
const double dropout,
45-
const InputDataType& attentionMask,
46-
const InputDataType& keyPaddingMask) :
43+
const arma::mat& attentionMask,
44+
const arma::mat& keyPaddingMask) :
4745
srcVocabSize(srcVocabSize),
4846
srcSeqLen(srcSeqLen),
4947
numEncoderLayers(numEncoderLayers),
@@ -72,25 +70,22 @@ BERT<OutputLayerType, InitType, InputDataType, OutputDataType>::BERT(
7270
dimFFN,
7371
dropout,
7472
attentionMask,
75-
keyPaddingMask
76-
);
73+
keyPaddingMask);
7774

7875
bert.Add(encoder.Model());
7976
}
8077
}
8178

82-
template<typename OutputLayerType, typename InitType, typename InputDataType,
83-
typename OutputDataType>
84-
void BERT<OutputLayerType, InitType, InputDataType, OutputDataType>::LoadModel(
79+
template<typename OutputLayerType, typename InitializationRuleType>
80+
void BERT<OutputLayerType, InitializationRuleType>::LoadModel(
8581
const std::string& filepath)
8682
{
8783
data::Load(filepath, "BERT", bert);
8884
std::cout << "Loaded model" << std::endl;
8985
}
9086

91-
template<typename OutputLayerType, typename InitType, typename InputDataType,
92-
typename OutputDataType>
93-
void BERT<OutputLayerType, InitType, InputDataType, OutputDataType>::SaveModel(
87+
template<typename OutputLayerType, typename InitializationRuleType>
88+
void BERT<OutputLayerType, InitializationRuleType>::SaveModel(
9489
const std::string& filepath)
9590
{
9691
std::cout << "Saving model" << std::endl;

models/bert/bert_tokenizer.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ template <
4444
class BertTokenizer
4545
{
4646
public:
47+
/**
48+
* Create a BertTokenizer object.
49+
*/
4750
BertTokenizer();
4851

4952
/**
@@ -124,6 +127,8 @@ class BertTokenizer
124127
std::string maskToken;
125128
}; // class BertTokenizer
126129

130+
// Include implementation.
131+
#include "bert_tokenizer_impl.hpp"
127132
} // namespace ann
128133
} // namespace mlpack
129134

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/**
2+
* @file models/bert/bert_tokenizer_impl.hpp
3+
* @author Mrityunjay Tripathi
4+
*
5+
* Implementation of the BERT Tokenizer.
6+
*
7+
* mlpack is free software; you may redistribute it and/or modify it under the
8+
* terms of the 3-clause BSD license. You should have received a copy of the
9+
* 3-clause BSD license along with mlpack. If not, see
10+
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
11+
*/
12+
13+
#ifndef MODELS_BERT_BERT_TOKENIZER_IMPL_HPP
14+
#define MODELS_BERT_BERT_TOKENIZER_IMPL_HPP
15+
16+
#include "bert_tokenizer.hpp"
17+
18+
namespace mlpack {
19+
namespace ann /** Artificial Neural Network. */ {
20+
21+
template<typename InputDataType, typename OutputDataType>
22+
BertTokenizer<InputDataType, OutputDataType>::BertTokenizer() :
23+
vocabFile(""),
24+
lowerCase(true),
25+
basicTokenize(true),
26+
unkToken("[UNK]"),
27+
sepToken("[SEP]"),
28+
padToken("[PAD]"),
29+
clsToken("[CLS]"),
30+
maskToken("[MASK]")
31+
{
32+
// Nothing to do here.
33+
}
34+
35+
template<typename InputDataType, typename OutputDataType>
36+
BertTokenizer<InputDataType, OutputDataType>::BertTokenizer(
37+
const std::string vocabFile,
38+
const bool lowerCase,
39+
const bool basicTokenize,
40+
const std::vector<std::string> neverSplit,
41+
const std::string unkToken,
42+
const std::string sepToken,
43+
const std::string padToken,
44+
const std::string clsToken,
45+
const std::string maskToken) :
46+
vocabFile(vocabFile),
47+
lowerCase(lowerCase),
48+
basicTokenize(basicTokenize),
49+
neverSplit(neverSplit),
50+
unkToken(unkToken),
51+
sepToken(sepToken),
52+
padToken(padToken),
53+
clsToken(clsToken),
54+
maskToken(maskToken)
55+
{
56+
// code here.
57+
}
58+
59+
} // namespace ann
60+
} // namespace mlpack
61+
62+
#endif

utils/utils.hpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class Utils
2929
public:
3030
/**
3131
* Determines whether a path exists.
32-
*
32+
*
3333
* @param path Global or relative path.
3434
* @param absolutePath Boolean to determine if path is absolute or relative.
3535
* @return true if path exists else false. Defaults to false.
@@ -329,5 +329,28 @@ class Utils
329329
mlpack::Log::Warn << "The " << path << " doesn't exist." << std::endl;
330330
}
331331
}
332+
333+
/**
334+
* Loads a vocabulary file and stores the content into a vector.
335+
*
336+
* @param vocabPath The path to the vocabulary file.
337+
* @param vocabulary Stores the vocabulary content along with indices.
338+
*/
339+
void LoadVocabulary(const std::string vocabPath,
340+
std::map<std::string, size_t>& vocabulary)
341+
{
342+
std::string token;
343+
std::ifstream vocabFile(vocabPath);
344+
if (vocabFile.is_open())
345+
{
346+
for (size_t i = 0; std::getline(vocabFile, token); ++i)
347+
vocabulary[token] = i;
348+
349+
vocabFile.close();
350+
}
351+
352+
else
353+
std::cout << "Unable to open vocabulary file!" << std::endl;
354+
}
332355
};
333356
#endif

0 commit comments

Comments
 (0)