
Commit 80a159c

Addressing feedback.
1 parent cb9738f commit 80a159c

File tree: 6 files changed (+47 −43 lines)


+bert/+tokenizer/+internal/FullTokenizer.m

Lines changed: 2 additions & 3 deletions
@@ -93,7 +93,7 @@
     % Default case
     this.Basic = bert.tokenizer.internal.BasicTokenizer('IgnoreCase',nvp.IgnoreCase);
 else
-    assert(isa(nvp.BasicTokenizer,'bert.tokenizer.internal.Tokenizer'),"BasicTokenizer must be a bert.tokenizer.internal.Tokenizer implementation.");
+    mustBeA(nvp.BasicTokenizer,'bert.tokenizer.internal.Tokenizer');
     this.Basic = nvp.BasicTokenizer;
 end
 this.WordPiece = bert.tokenizer.internal.WordPieceTokenizer(vocab);
@@ -106,10 +106,9 @@
 % tokens = tokenize(tokenizer,text) tokenizes the input
 % string text using the FullTokenizer specified by tokenizer.
 basicToks = this.Basic.tokenize(txt);
-basicToksUnicode = cellfun(@textanalytics.unicode.UTF32,basicToks,UniformOutput=false);
 toks = cell(numel(txt),1);
 for i = 1:numel(txt)
-    theseBasicToks = basicToksUnicode{i};
+    theseBasicToks = textanalytics.unicode.UTF32(basicToks{i});
     theseSubToks = cell(numel(theseBasicToks),1);
     for j = 1:numel(theseBasicToks)
         theseSubToks{j} = this.WordPiece.tokenize(theseBasicToks(j));
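
Two things change here. The type check on a user-supplied BasicTokenizer moves from a hand-rolled assert to the standard mustBeA validator, and the UTF32 conversion now happens per document inside the loop rather than eagerly over the whole batch via cellfun. A minimal sketch of the validator swap (the variable tok is illustrative):

% Before: custom assert with a hand-maintained message.
assert(isa(tok,'bert.tokenizer.internal.Tokenizer'), ...
    "BasicTokenizer must be a bert.tokenizer.internal.Tokenizer implementation.");
% After: mustBeA throws a standard, consistently worded error and,
% like isa, accepts subclasses of the named class.
mustBeA(tok,'bert.tokenizer.internal.Tokenizer');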

+bert/+tokenizer/BERTTokenizer.m

Lines changed: 1 addition & 1 deletion
@@ -110,7 +110,7 @@
 ignoreCase = nvp.IgnoreCase;
 this.FullTokenizer = bert.tokenizer.internal.FullTokenizer(vocabFile,'IgnoreCase',ignoreCase);
 else
-    assert(isa(nvp.FullTokenizer,'bert.tokenizer.internal.FullTokenizer'),"FullTokenizer must be a bert.tokenizer.internal.FullTokenizer.");
+    mustBeA(nvp.FullTokenizer,'bert.tokenizer.internal.FullTokenizer');
     this.FullTokenizer = nvp.FullTokenizer;
 end
 this.PaddingCode = this.FullTokenizer.encode(this.PaddingToken);
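
The same validator swap as in FullTokenizer.m. For reference, the code path being validated is the one bert.m now takes for the Japanese models, where a prebuilt FullTokenizer is passed in (modelDir here stands in for the unzipped model directory):

btok = bert.tokenizer.internal.TokenizedDocumentTokenizer("Language","ja","TokenizeMethod","mecab",IgnoreCase=false);
vocabFile = fullfile(modelDir, "vocab.txt");
ftok = bert.tokenizer.internal.FullTokenizer(vocabFile,BasicTokenizer=btok);
tok = bert.tokenizer.BERTTokenizer(vocabFile,FullTokenizer=ftok);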

FineTuneBERTJapanese.m

Lines changed: 1 addition & 0 deletions
@@ -135,6 +135,7 @@

 %%
 % Initialize training progress plot.
+% In R2023a you can use trainingProgressMonitor.
 figure
 C = colororder;
 lineLossTrain = animatedline("Color",C(2,:));
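
The added comment points at a newer API. A minimal sketch of what the progress plot could look like with trainingProgressMonitor instead of figure/animatedline (numIterations and the loss value are placeholders; the monitor API ships with R2022b and later):

numIterations = 100;
monitor = trainingProgressMonitor(Metrics="Loss",XLabel="Iteration");
iteration = 0;
while iteration < numIterations && ~monitor.Stop
    iteration = iteration + 1;
    loss = 1/iteration;   % placeholder for the fine-tuning loss
    recordMetrics(monitor,iteration,Loss=loss);
    monitor.Progress = 100*iteration/numIterations;
end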

bert.m

Lines changed: 24 additions & 38 deletions
@@ -23,45 +23,9 @@

 switch nvp.Model
     case "japanese-base"
-        zipFilePath = bert.internal.getSupportFilePath("japanese-base", "bert-base-japanese.zip");
-        modelDir = fullfile(fileparts(zipFilePath), "bert-base-japanese");
-        unzip(zipFilePath, modelDir);
-        % Build the tokenizer
-        btok = bert.tokenizer.internal.TokenizedDocumentTokenizer("Language","ja","TokenizeMethod","mecab",IgnoreCase=false);
-        vocabFile = fullfile(modelDir, "vocab.txt");
-        ftok = bert.tokenizer.internal.FullTokenizer(vocabFile,BasicTokenizer=btok);
-        tok = bert.tokenizer.BERTTokenizer(vocabFile,FullTokenizer=ftok);
-        % Build the model
-        params.Weights = load(fullfile(modelDir, "weights.mat"));
-        params.Weights = dlupdate(@dlarray,params.Weights);
-        params.Hyperparameters = struct(...
-            NumHeads=12,...
-            NumLayers=12,...
-            NumContext=512,...
-            HiddenSize=768);
-        mdl = struct(...
-            Tokenizer=tok,...
-            Parameters=params);
+        mdl = iJapaneseBERTModel("japanese-base", "bert-base-japanese.zip");
     case "japanese-base-wwm"
-        zipFilePath = bert.internal.getSupportFilePath("japanese-base", "bert-base-japanese-whole-word-masking.zip");
-        modelDir = fullfile(fileparts(zipFilePath), "bert-base-japanese-whole-word-masking");
-        unzip(zipFilePath, modelDir);
-        % Build the tokenizer
-        btok = bert.tokenizer.internal.TokenizedDocumentTokenizer("Language","ja","TokenizeMethod","mecab",IgnoreCase=false);
-        vocabFile = fullfile(modelDir, "vocab.txt");
-        ftok = bert.tokenizer.internal.FullTokenizer(vocabFile,BasicTokenizer=btok);
-        tok = bert.tokenizer.BERTTokenizer(vocabFile,FullTokenizer=ftok);
-        % Build the model
-        params.Weights = load(fullfile(modelDir, "weights.mat"));
-        params.Weights = dlupdate(@dlarray,params.Weights);
-        params.Hyperparameters = struct(...
-            NumHeads=12,...
-            NumLayers=12,...
-            NumContext=512,...
-            HiddenSize=768);
-        mdl = struct(...
-            Tokenizer=tok,...
-            Parameters=params);
+        mdl = iJapaneseBERTModel("japanese-base-wwm", "bert-base-japanese-whole-word-masking.zip");
     otherwise
         % Download the license file
         bert.internal.getSupportFilePath(nvp.Model,"bert.RIGHTS");
@@ -76,4 +40,26 @@
     'Tokenizer',bert.tokenizer.BERTTokenizer(vocabFile,'IgnoreCase',ignoreCase),...
     'Parameters',params);
 end
+end
+
+function mdl = iJapaneseBERTModel(modelName, zipFileName)
+zipFilePath = bert.internal.getSupportFilePath(modelName, zipFileName);
+modelDir = fullfile(fileparts(zipFilePath), replace(zipFileName, ".zip", ""));
+unzip(zipFilePath, modelDir);
+% Build the tokenizer
+btok = bert.tokenizer.internal.TokenizedDocumentTokenizer("Language","ja","TokenizeMethod","mecab",IgnoreCase=false);
+vocabFile = fullfile(modelDir, "vocab.txt");
+ftok = bert.tokenizer.internal.FullTokenizer(vocabFile,BasicTokenizer=btok);
+tok = bert.tokenizer.BERTTokenizer(vocabFile,FullTokenizer=ftok);
+% Build the model
+params.Weights = load(fullfile(modelDir, "weights.mat"));
+params.Weights = dlupdate(@dlarray,params.Weights);
+params.Hyperparameters = struct(...
+    NumHeads=12,...
+    NumLayers=12,...
+    NumContext=512,...
+    HiddenSize=768);
+mdl = struct(...
+    Tokenizer=tok,...
+    Parameters=params);
 end
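
The two near-identical ~20-line case bodies differed only in the support-file model name and the zip file name, so the helper takes exactly those two inputs and derives the extraction directory from the zip name:

% The directory name is the zip file name minus its extension, e.g.
replace("bert-base-japanese.zip", ".zip", "")   % "bert-base-japanese"
mdl = iJapaneseBERTModel("japanese-base", "bert-base-japanese.zip");
% mdl is a struct with Tokenizer and Parameters fields, as before.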

test/bert/tokenizer/internal/tBasicTokenizer.m

Lines changed: 9 additions & 0 deletions
@@ -17,6 +17,15 @@ function canTokenize(test)
     act_out = tok.tokenize(str);
     test.verifyEqual(act_out,exp_out);
 end
+
+function canTokenizeBatch(test)
+    tok = bert.tokenizer.internal.BasicTokenizer();
+    manyStrs = repmat("foo bar baz",1,20);
+    act_out = tokenize(tok, manyStrs);
+    exp_out = arrayfun(@(str) tokenize(tok,str),manyStrs,UniformOutput=false);
+    exp_out = [exp_out{:}];
+    test.verifyEqual(act_out,exp_out);
+end

 function removesControlCharactersAndWhitespace(test)
     tok = bert.tokenizer.internal.BasicTokenizer();
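
The new test pins down the batch contract: tokenizing a string array must equal tokenizing each element separately and concatenating the results. The expected-output construction relies on tokenize returning a one-element cell per scalar string, an assumption this sketch makes explicit:

tok = bert.tokenizer.internal.BasicTokenizer();
perString = arrayfun(@(s) tokenize(tok,s), ["foo bar" "baz"], UniformOutput=false);
% perString is a 1x2 cell of 1x1 cells; the comma-separated list
% expansion below flattens it to the 1x2 cell that batch tokenize returns.
expected = [perString{:}];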

test/tbert.m

Lines changed: 10 additions & 1 deletion
@@ -1,19 +1,28 @@
 classdef(SharedTestFixtures = {
-    DownloadBERTFixture}) tbert < matlab.unittest.TestCase
+    DownloadBERTFixture, DownloadJPBERTFixture}) tbert < matlab.unittest.TestCase
 % tbert System level tests for bert

 % Copyright 2021 The MathWorks, Inc.

 properties(TestParameter)
     UncasedVersion = {"base", ...
         "tiny"}
+    AllModels = {"base","multilingual-cased","medium",...
+        "small","mini","tiny","japanese-base",...
+        "japanese-base-wwm"}
 end

 methods(Test)
+
     function canConstructModelWithDefault(test)
         % Verify the default model can be constructed.
         test.verifyWarningFree(@() bert());
     end
+
+    function canConstructAllModels(test, AllModels)
+        % Verify all available models can be constructed.
+        test.verifyWarningFree(@() bert('Model', AllModels));
+    end

     function canConstructModelWithNVPAndVerifyDefault(test)
         % Verify the default model matches the default model.
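
canConstructAllModels is parameterized over AllModels, so the test framework expands it into one test point per model name, each asserting that construction completes without warnings. To run the class and see one result row per model (note the shared fixtures download every model first):

results = runtests("tbert");
table(results)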
