-// Licensed to the .NET Foundation under one or more agreements.
+// Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 // See the LICENSE file in the project root for more information.
 
@@ -14,6 +14,91 @@ namespace Microsoft.ML.Tokenizers.Tests
 {
     public class BertTokenizerTests
     {
+        [Fact]
+        public void TestWithLowerCasingExplicitSpecialTokens()
+        {
+            // Add the [SPECIAL] token at the end (to keep the other indices as-is).
+            // Ids: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13
+            string[] vocabTokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "!", ",", "?", "hello", "world", "how", "are", "you", "[SPECIAL]"];
+
+            string vocabFile = WordPieceTests.CreateVocabFile(vocabTokens);
+
+            Dictionary<string, int> specialTokens = new()
+            {
+                { "[PAD]", 0 },
+                { "[UNK]", 1 },
+                { "[CLS]", 2 },
+                { "[SEP]", 3 },
+                { "[MASK]", 4 },
+                { "[SPECIAL]", 13 },
+            };
+            var bertOptions = new BertOptions()
+            {
+                SpecialTokens = specialTokens
+            };
+
+            try
+            {
+                using Stream vocabStream = File.OpenRead(vocabFile);
+                BertTokenizer[] bertTokenizers = [BertTokenizer.Create(vocabFile, bertOptions), BertTokenizer.Create(vocabStream, bertOptions)];
+
+                foreach (var tokenizer in bertTokenizers)
+                {
+                    Assert.NotNull(tokenizer.PreTokenizer);
+                    Assert.Equal("[UNK]", tokenizer.UnknownToken);
+                    Assert.Equal(1, tokenizer.UnknownTokenId);
+                    Assert.NotNull(tokenizer.Normalizer);
+
+                    Assert.True(tokenizer.SpecialTokens!.ContainsKey("[SPECIAL]"));
+
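+                    // The normalizer lowercases the whole input, including the special tokens,
+                    // yet the lowercased "[special]" is still matched and mapped to its id.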
+                    string text = "Hello, How are you [SPECIAL]?";
+                    var tokens = tokenizer.EncodeToTokens(text, out string? normalizedText);
+                    Assert.Equal("hello, how are you [special]?", normalizedText);
+
+                    Assert.Equal(
+                        [
+                            new EncodedToken(8, "hello", new Range(0, 5)),
+                            new EncodedToken(6, ",", new Range(5, 6)),
+                            new EncodedToken(10, "how", new Range(7, 10)),
+                            new EncodedToken(11, "are", new Range(11, 14)),
+                            new EncodedToken(12, "you", new Range(15, 18)),
+                            new EncodedToken(13, "[SPECIAL]", new Range(19, 28)),
+                            new EncodedToken(7, "?", new Range(28, 29))
+                        ],
+                        tokens);
+
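+                    // EncodeToIds wraps the sequence with the classification ([CLS]) and
+                    // separator ([SEP]) token ids.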
+                    var ids = tokenizer.EncodeToIds(text);
+                    Assert.Equal([tokenizer.ClassificationTokenId, 8, 6, 10, 11, 12, 13, 7, tokenizer.SeparatorTokenId], ids);
+
+                    Assert.Equal("[CLS] hello, how are you [SPECIAL]? [SEP]", tokenizer.Decode(ids));
+                    Assert.Equal("hello, how are you?", tokenizer.Decode(ids, skipSpecialTokens: true));
+
+                    tokens = tokenizer.EncodeToTokens(tokenizer.Decode(ids), out normalizedText);
+                    Assert.Equal("[cls] hello, how are you [special]? [sep]", normalizedText);
+                    Assert.Equal(
+                        [
+                            new EncodedToken(2, "[CLS]", new Range(0, 5)),
+                            new EncodedToken(8, "hello", new Range(6, 11)),
+                            new EncodedToken(6, ",", new Range(11, 12)),
+                            new EncodedToken(10, "how", new Range(13, 16)),
+                            new EncodedToken(11, "are", new Range(17, 20)),
+                            new EncodedToken(12, "you", new Range(21, 24)),
+                            new EncodedToken(13, "[SPECIAL]", new Range(25, 34)),
+                            new EncodedToken(7, "?", new Range(34, 35)),
+                            new EncodedToken(3, "[SEP]", new Range(36, 41))
+                        ],
+                        tokens);
+
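+                    // Encoding text that already contains [CLS]/[SEP] yields a second pair,
+                    // since EncodeToIds inserts them again by default.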
+                    ids = tokenizer.EncodeToIds(normalizedText!);
+                    Assert.Equal([tokenizer.ClassificationTokenId, tokenizer.ClassificationTokenId, 8, 6, 10, 11, 12, 13, 7, tokenizer.SeparatorTokenId, tokenizer.SeparatorTokenId], ids);
+                }
+            }
+            finally
+            {
+                File.Delete(vocabFile);
+            }
+        }
+
         [Fact]
         public void TestWithLowerCasing()
         {
@@ -35,6 +120,10 @@ public void TestWithLowerCasing()
             Assert.NotNull(tokenizer.Normalizer);
             Assert.NotNull(tokenizer.PreTokenizer);
 
+            // Make sure the SpecialTokens dictionary contains the original, non-normalized tokens.
+            Assert.True(tokenizer.SpecialTokens!.ContainsKey(tokenizer.UnknownToken));
+            Assert.True(tokenizer.SpecialTokens!.ContainsKey(tokenizer.ClassificationToken));
+
             string text = "Hello, How are you?";
             var tokens = tokenizer.EncodeToTokens(text, out string? normalizedText);
             Assert.Equal("hello, how are you?", normalizedText);
@@ -511,4 +600,4 @@ public void TestCreateTokenTypeIdsFromSequences()
             }
         }
     }
-}
+}