Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 30 additions & 25 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ function replaceTensors(obj) {

/**
* Converts an array or Tensor of integers to an int64 Tensor.
* @param {Array|Tensor} items The input integers to be converted.
* @param {any[]|Tensor} items The input integers to be converted.
* @returns {Tensor} The int64 Tensor with the converted values.
* @throws {Error} If the input array is empty or the input is a batched Tensor and not all sequences have the same length.
* @private
Expand Down Expand Up @@ -1334,35 +1334,37 @@ export class PreTrainedModel extends Callable {
let { decoder_input_ids, ...model_inputs } = model_kwargs;

// Prepare input ids if the user has not defined `decoder_input_ids` manually.
if (!decoder_input_ids) {
decoder_start_token_id ??= bos_token_id;

if (this.config.model_type === 'musicgen') {
// Custom logic (TODO: move to Musicgen class)
decoder_input_ids = Array.from({
length: batch_size * this.config.decoder.num_codebooks
}, () => [decoder_start_token_id]);

} else if (Array.isArray(decoder_start_token_id)) {
if (decoder_start_token_id.length !== batch_size) {
throw new Error(
`\`decoder_start_token_id\` expcted to have length ${batch_size} but got ${decoder_start_token_id.length}`
)
if (!(decoder_input_ids instanceof Tensor)) {
if (!decoder_input_ids) {
decoder_start_token_id ??= bos_token_id;

if (this.config.model_type === 'musicgen') {
// Custom logic (TODO: move to Musicgen class)
decoder_input_ids = Array.from({
length: batch_size * this.config.decoder.num_codebooks
}, () => [decoder_start_token_id]);

} else if (Array.isArray(decoder_start_token_id)) {
if (decoder_start_token_id.length !== batch_size) {
throw new Error(
`\`decoder_start_token_id\` expcted to have length ${batch_size} but got ${decoder_start_token_id.length}`
)
}
decoder_input_ids = decoder_start_token_id;
} else {
decoder_input_ids = Array.from({
length: batch_size,
}, () => [decoder_start_token_id]);
}
decoder_input_ids = decoder_start_token_id;
} else {
} else if (!Array.isArray(decoder_input_ids[0])) {
// Correct batch size
decoder_input_ids = Array.from({
length: batch_size,
}, () => [decoder_start_token_id]);
}, () => decoder_input_ids);
}
} else if (!Array.isArray(decoder_input_ids[0])) {
// Correct batch size
decoder_input_ids = Array.from({
length: batch_size,
}, () => decoder_input_ids);
decoder_input_ids = toI64Tensor(decoder_input_ids);
}

decoder_input_ids = toI64Tensor(decoder_input_ids);
model_kwargs['decoder_attention_mask'] = ones_like(decoder_input_ids);

return { input_ids: decoder_input_ids, model_inputs };
Expand Down Expand Up @@ -3185,8 +3187,11 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
export class VisionEncoderDecoderModel extends PreTrainedModel {
main_input_name = 'pixel_values';
forward_params = [
// Encoder inputs
'pixel_values',
'input_ids',

// Decoder inpputs
'decoder_input_ids',
'encoder_hidden_states',
'past_key_values',
];
Expand Down
1 change: 0 additions & 1 deletion src/pipelines.js
Original file line number Diff line number Diff line change
Expand Up @@ -2566,7 +2566,6 @@ export class DocumentQuestionAnsweringPipeline extends (/** @type {new (options:

/** @type {DocumentQuestionAnsweringPipelineCallback} */
async _call(image, question, generate_kwargs = {}) {
throw new Error('This pipeline is not yet supported in Transformers.js v3.'); // TODO: Remove when implemented

// NOTE: For now, we only support a batch size of 1

Expand Down
2 changes: 2 additions & 0 deletions src/processors.js
Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,7 @@ export class DonutFeatureExtractor extends ImageFeatureExtractor {
});
}
}
export class DonutImageProcessor extends DonutFeatureExtractor { } // NOTE extends DonutFeatureExtractor
export class NougatImageProcessor extends DonutFeatureExtractor { } // NOTE extends DonutFeatureExtractor

/**
Expand Down Expand Up @@ -2569,6 +2570,7 @@ export class AutoProcessor {
MaskFormerFeatureExtractor,
YolosFeatureExtractor,
DonutFeatureExtractor,
DonutImageProcessor,
NougatImageProcessor,
EfficientNetImageProcessor,

Expand Down
Loading
Loading