@@ -290,9 +290,25 @@ public IReadOnlyList<int> BuildInputsWithSpecialTokens(IEnumerable<int> tokenIds
290290                throw  new  ArgumentNullException ( nameof ( tokenIds0 ) ) ; 
291291            } 
292292
293-             // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null. 
294-             int  capacity  =  tokenIds0 . Count ( )  +  2  +  ( tokenIds1  is  null  ?  0  :  tokenIds1 . Count ( )  +  1 ) ; 
295-             List < int >  ids  =  new  List < int > ( capacity :  capacity )  {  ClsTokenId  } ; 
293+             List < int >  ids ; 
294+ 
295+             if  ( tokenIds0  is  ICollection < int >  c1 ) 
296+             { 
297+                 int  capacity  =  c1 . Count  +  2 ;     // Add 2 for [CLS] and two [SEP] tokens. 
298+ 
299+                 if  ( tokenIds1  is  not null ) 
300+                 { 
301+                     capacity  +=  tokenIds1  is  ICollection < int >  c2  ?  c2 . Count  +  1  :  c1 . Count  +  1 ; 
302+                 } 
303+ 
304+                 ids  =  new ( capacity )  {  ClsTokenId  } ; 
305+             } 
306+             else 
307+             { 
308+                 // slow path 
309+                 ids  =  new  List < int > ( 10 )  {  ClsTokenId  } ; 
310+             } 
311+ 
296312            ids . AddRange ( tokenIds0 ) ; 
297313            ids . Add ( SepTokenId ) ; 
298314
@@ -323,29 +339,48 @@ public OperationStatus BuildInputsWithSpecialTokens(IEnumerable<int> tokenIds0,
323339                throw  new  ArgumentNullException ( nameof ( tokenIds0 ) ) ; 
324340            } 
325341
326-             // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null. 
327-             int  capacity  =  tokenIds0 . Count ( )  +  2  +  ( tokenIds1  is  null  ?  0  :  tokenIds1 . Count ( )  +  1 ) ; 
328-             if  ( buffer . Length  <  capacity ) 
342+             written  =  0 ; 
343+             if  ( buffer . Length  <  1 ) 
329344            { 
330-                 written  =  0 ; 
331345                return  OperationStatus . DestinationTooSmall ; 
332346            } 
333347
334-             written  =  0 ; 
335348            buffer [ written ++ ]  =  ClsTokenId ; 
336349            foreach  ( int  id  in  tokenIds0 ) 
337350            { 
351+                 if  ( buffer . Length  <=  written ) 
352+                 { 
353+                     written  =  0 ; 
354+                     return  OperationStatus . DestinationTooSmall ; 
355+                 } 
356+ 
338357                buffer [ written ++ ]  =  id ; 
339358            } 
359+ 
360+             if  ( buffer . Length  <=  written ) 
361+             { 
362+                 written  =  0 ; 
363+                 return  OperationStatus . DestinationTooSmall ; 
364+             } 
340365            buffer [ written ++ ]  =  SepTokenId ; 
341366
342367            if  ( tokenIds1  is  not null ) 
343368            { 
344369                foreach  ( int  id  in  tokenIds1 ) 
345370                { 
371+                     if  ( buffer . Length  <=  written ) 
372+                     { 
373+                         written  =  0 ; 
374+                         return  OperationStatus . DestinationTooSmall ; 
375+                     } 
346376                    buffer [ written ++ ]  =  id ; 
347377                } 
348378
379+                 if  ( buffer . Length  <=  written ) 
380+                 { 
381+                     written  =  0 ; 
382+                     return  OperationStatus . DestinationTooSmall ; 
383+                 } 
349384                buffer [ written ++ ]  =  SepTokenId ; 
350385            } 
351386
@@ -367,11 +402,22 @@ public IReadOnlyList<int> GetSpecialTokensMask(IEnumerable<int> tokenIds0, IEnum
367402                throw  new  ArgumentNullException ( nameof ( tokenIds0 ) ) ; 
368403            } 
369404
370-             int  capacity  =  alreadyHasSpecialTokens  ? 
371-                         tokenIds0 . Count ( )  +  ( tokenIds1 ? . Count ( )  ??  0 )  : 
372-                         tokenIds0 . Count ( )  +  2  +  ( tokenIds1  is  null  ?  0  :  1 ) ;     // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null. 
405+             List < int >  mask ; 
406+             if  ( tokenIds0  is  ICollection < int >  c1 ) 
407+             { 
408+                 int  capcity  =  c1 . Count  +  2 ; 
409+ 
410+                 if  ( tokenIds1  is  not null ) 
411+                 { 
412+                     capcity  +=  tokenIds1  is  ICollection < int >  c2  ?  c2 . Count  +  1  :  c1 . Count  +  1 ; 
413+                 } 
373414
374-             List < int >  mask  =  new  List < int > ( capacity :  capacity ) ; 
415+                 mask  =  new  List < int > ( capcity ) ; 
416+             } 
417+             else 
418+             { 
419+                 mask  =  new  List < int > ( 10 ) ; 
420+             } 
375421
376422            if  ( ! alreadyHasSpecialTokens ) 
377423            { 
@@ -420,31 +466,49 @@ public OperationStatus GetSpecialTokensMask(IEnumerable<int> tokenIds0, Span<int
420466                throw  new  ArgumentNullException ( nameof ( tokenIds0 ) ) ; 
421467            } 
422468
423-             int  capacity  =  alreadyHasSpecialTokens  ? 
424-                         tokenIds0 . Count ( )  +  ( tokenIds1 ? . Count ( )  ??  0 )  : 
425-                         tokenIds0 . Count ( )  +  2  +  ( tokenIds1  is  null  ?  0  :  tokenIds1 . Count ( )  +  1 ) ;     // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null. 
426- 
427469            written  =  0 ; 
428-             if  ( buffer . Length  <  capacity ) 
429-             { 
430-                 return  OperationStatus . DestinationTooSmall ; 
431-             } 
432- 
433470            if  ( ! alreadyHasSpecialTokens ) 
434471            { 
472+                 if  ( buffer . Length  <  1 ) 
473+                 { 
474+                     return  OperationStatus . DestinationTooSmall ; 
475+                 } 
435476                buffer [ written ++ ]  =  1 ;  // CLS 
477+ 
436478                foreach  ( int  id  in  tokenIds0 ) 
437479                { 
480+                     if  ( buffer . Length  <=  written ) 
481+                     { 
482+                         written  =  0 ; 
483+                         return  OperationStatus . DestinationTooSmall ; 
484+                     } 
438485                    buffer [ written ++ ]  =  0 ; 
439486                } 
487+ 
488+                 if  ( buffer . Length  <=  written ) 
489+                 { 
490+                     written  =  0 ; 
491+                     return  OperationStatus . DestinationTooSmall ; 
492+                 } 
440493                buffer [ written ++ ]  =  1 ;  // SEP 
441494
442495                if  ( tokenIds1  is  not null ) 
443496                { 
444497                    foreach  ( int  id  in  tokenIds1 ) 
445498                    { 
499+                         if  ( buffer . Length  <=  written ) 
500+                         { 
501+                             written  =  0 ; 
502+                             return  OperationStatus . DestinationTooSmall ; 
503+                         } 
446504                        buffer [ written ++ ]  =  0 ; 
447505                    } 
506+ 
507+                     if  ( buffer . Length  <=  written ) 
508+                     { 
509+                         written  =  0 ; 
510+                         return  OperationStatus . DestinationTooSmall ; 
511+                     } 
448512                    buffer [ written ++ ]  =  1 ;  // SEP 
449513                } 
450514
@@ -453,13 +517,23 @@ public OperationStatus GetSpecialTokensMask(IEnumerable<int> tokenIds0, Span<int
453517
454518            foreach  ( int  id  in  tokenIds0 ) 
455519            { 
520+                 if  ( buffer . Length  <=  written ) 
521+                 { 
522+                     written  =  0 ; 
523+                     return  OperationStatus . DestinationTooSmall ; 
524+                 } 
456525                buffer [ written ++ ]  =  id  ==  ClsTokenId  ||  id  ==  SepTokenId  ||  id  ==  PadTokenId  ||  id  ==  MaskTokenId  ||  id  ==  UnknownTokenId  ?  1  :  0 ; 
457526            } 
458527
459528            if  ( tokenIds1  is  not null ) 
460529            { 
461530                foreach  ( int  id  in  tokenIds1 ) 
462531                { 
532+                     if  ( buffer . Length  <=  written ) 
533+                     { 
534+                         written  =  0 ; 
535+                         return  OperationStatus . DestinationTooSmall ; 
536+                     } 
463537                    buffer [ written ++ ]  =  id  ==  ClsTokenId  ||  id  ==  SepTokenId  ||  id  ==  PadTokenId  ||  id  ==  MaskTokenId  ||  id  ==  UnknownTokenId  ?  1  :  0 ; 
464538                } 
465539            } 
@@ -484,21 +558,38 @@ public IReadOnlyList<int> CreateTokenTypeIdsFromSequences(IEnumerable<int> token
484558                throw  new  ArgumentNullException ( nameof ( tokenIds0 ) ) ; 
485559            } 
486560
487-             // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null. 
488-             int  capacity  =  tokenIds0 . Count ( )  +  2  +  ( tokenIds1  is  null  ?  0  :  tokenIds1 . Count ( )  +  1 ) ; 
561+             List < int >  typeIds ; 
562+             if  ( tokenIds0  is  ICollection < int >  c1 ) 
563+             { 
564+                 int  capacity  =  c1 . Count  +  2 ;     // Add 2 for [CLS] and [SEP] tokens. 
565+ 
566+                 if  ( tokenIds1  is  not null ) 
567+                 { 
568+                     capacity  +=  tokenIds1  is  ICollection < int >  c2  ?  c2 . Count  +  1  :  c1 . Count  +  1 ; 
569+                 } 
489570
490-             List < int >  typeIds  =  new  List < int > ( capacity ) ; 
491-             for  ( int  i  =  0 ;  i  <  tokenIds0 . Count ( )  +  2 ;  i ++ )  // Add 2 for [CLS] and [SEP] tokens. 
571+                 typeIds  =  new  List < int > ( capacity ) ; 
572+             } 
573+             else 
574+             { 
575+                 typeIds  =  new  List < int > ( 10 ) ; 
576+             } 
577+ 
578+             foreach  ( var  id  in  tokenIds0 ) 
492579            { 
493580                typeIds . Add ( 0 ) ; 
494581            } 
582+             typeIds . Add ( 0 ) ;  // [CLS] 
583+             typeIds . Add ( 0 ) ;  // [SEP] 
495584
496585            if  ( tokenIds1  is  not null ) 
497586            { 
498-                 for  ( int  i   =   0 ;   i   <   tokenIds1 . Count ( )   +   1 ;   i ++ )   // Add 1 for [SEP] token. 
587+                 foreach  ( int  id   in   tokenIds1 ) 
499588                { 
500589                    typeIds . Add ( 1 ) ; 
501590                } 
591+ 
592+                 typeIds . Add ( 1 ) ;  // [SEP] 
502593            } 
503594
504595            return  typeIds ; 
@@ -515,22 +606,40 @@ public OperationStatus CreateTokenTypeIdsFromSequences(IEnumerable<int> tokenIds
515606
516607            // Add 2 for [CLS] and [SEP] tokens. Add 1 for [SEP] token if tokenIds1 is not null. 
517608            int  capacity  =  tokenIds0 . Count ( )  +  2  +  ( tokenIds1  is  null  ?  0  :  tokenIds1 . Count ( )  +  1 ) ; 
518-             if  ( buffer . Length  <  capacity ) 
609+             if  ( buffer . Length  <  2 ) 
519610            { 
520611                return  OperationStatus . DestinationTooSmall ; 
521612            } 
613+             buffer [ written ++ ]  =  0 ;  // [CLS] 
614+             buffer [ written ++ ]  =  0 ;  // [SEP] 
522615
523-             for  ( int  i   =   0 ;   i   <   tokenIds0 . Count ( )   +   2 ;   i ++ )   // Add 2 for [CLS] and [SEP] tokens. 
616+             foreach  ( int  id   in   tokenIds0 ) 
524617            { 
618+                 if  ( buffer . Length  <=  written ) 
619+                 { 
620+                     written  =  0 ; 
621+                     return  OperationStatus . DestinationTooSmall ; 
622+                 } 
525623                buffer [ written ++ ]  =  0 ; 
526624            } 
527625
528626            if  ( tokenIds1  is  not null ) 
529627            { 
530-                 for  ( int  i   =   0 ;   i   <   tokenIds1 . Count ( )   +   1 ;   i ++ )   // Add 1 for [SEP] token. 
628+                 foreach  ( int  id   in   tokenIds1 ) 
531629                { 
630+                     if  ( buffer . Length  <=  written ) 
631+                     { 
632+                         written  =  0 ; 
633+                         return  OperationStatus . DestinationTooSmall ; 
634+                     } 
532635                    buffer [ written ++ ]  =  1 ; 
533636                } 
637+ 
638+                 if  ( buffer . Length  <  written ) 
639+                 { 
640+                     return  OperationStatus . DestinationTooSmall ; 
641+                 } 
642+                 buffer [ written ++ ]  =  1 ;  // [SEP] 
534643            } 
535644
536645            return  OperationStatus . Done ; 
0 commit comments