Skip to content

Commit ac0096e

Browse files
do-me and xenova authored
Add default token_type_ids for multilingual-e5-* models (#403)
* Fix #267 & #324: Add default token_type_ids. Fix for multilingual-e5-* family.
* Add add_token_types import
* export `add_token_types`
* Improvements
---------
Co-authored-by: Joshua Lochner <[email protected]>
1 parent b8719b1 commit ac0096e

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

src/models.js

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,10 @@ import {
4242
AutoConfig,
4343
} from './configs.js';
4444

45+
import {
46+
add_token_types,
47+
} from './tokenizers.js';
48+
4549
import {
4650
Callable,
4751
isIntegralNumber,
@@ -488,10 +492,15 @@ function seq2seqUpdatebeam(beam, newTokenId) {
488492
* @private
489493
*/
490494
async function encoderForward(self, model_inputs) {
491-
let encoderFeeds = {};
492-
for (let key of self.session.inputNames) {
495+
const encoderFeeds = Object.create(null);
496+
for (const key of self.session.inputNames) {
493497
encoderFeeds[key] = model_inputs[key];
494498
}
499+
if (self.session.inputNames.includes('token_type_ids') && !encoderFeeds.token_type_ids) {
500+
// Assign default `token_type_ids` to the `encoderFeeds` if the model expects it,
501+
// but they weren't created by the tokenizer.
502+
add_token_types(encoderFeeds);
503+
}
495504
return await sessionRun(self.session, encoderFeeds);
496505
}
497506

src/tokenizers.js

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2578,7 +2578,7 @@ export class PreTrainedTokenizer extends Callable {
25782578
* @param {Object} inputs An object containing the input ids and attention mask.
25792579
* @returns {Object} The prepared inputs object.
25802580
*/
2581-
function add_token_types(inputs) {
2581+
export function add_token_types(inputs) {
25822582
// TODO ensure correctness when token pair is present
25832583
if (inputs.input_ids instanceof Tensor) {
25842584
inputs.token_type_ids = new Tensor(

0 commit comments

Comments
 (0)