Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions src/tokenizers.js
Original file line number Diff line number Diff line change
Expand Up @@ -2787,22 +2787,29 @@ export class PreTrainedTokenizer extends Callable {
// For single input, we just wrap in an array, and then unwrap later.
encodedTokens = [this._encode_plus(text, { text_pair, add_special_tokens, return_token_type_ids })];
}
// At this point, tokens is batched: [batch_size, tokens]
// However, array may be jagged. So, we pad to max_length

// At this point, `encodedTokens` is batched, of shape [batch_size, tokens].
// However, array may be jagged. So, we may need to pad to max_length.
if (max_length === null) {
if (padding === 'max_length') {
max_length = this.model_max_length;
} else if (truncation === null) {
if (padding === true) {
console.warn(
"`max_length` is ignored when `padding: true` and there is no truncation strategy. " +
"To pad to max length, use `padding: 'max_length'`."
)
max_length = this.model_max_length;
} else {
// Calculate max length from sequences
max_length = max(encodedTokens.map(x => x.input_ids.length))[0];
}
} else {
if (!truncation) {
console.warn(`Truncation was not explicitly activated but \`max_length\` is provided a specific value, please use \`truncation=true\` to explicitly truncate examples to max length.`)
} else if (padding === false) {
console.warn("Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation: true` to explicitly truncate examples to max length.");
truncation = true;
}
}

// padding: 'max_length' doesn't require any additional calculation
// but padding: true has to calculate max_length from the sequences
if (padding === true) {
max_length = Math.min(max(encodedTokens.map(x => x.input_ids.length))[0], max_length ?? Infinity);
}

// Ensure it is less than model max length
max_length = Math.min(max_length, this.model_max_length ?? Infinity);

Expand Down
Loading