Skip to content

Commit cf0c9c1

Browse files
authored
Fix Document QA pipeline (#987)
* Fix Document QA pipeline
* Add `DonutImageProcessor`
* Update unit tests
1 parent e8c0f77 commit cf0c9c1

File tree

4 files changed

+754
-520
lines changed

4 files changed

+754
-520
lines changed

src/models.js

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ function replaceTensors(obj) {
411411

412412
/**
413413
* Converts an array or Tensor of integers to an int64 Tensor.
414-
* @param {Array|Tensor} items The input integers to be converted.
414+
* @param {any[]|Tensor} items The input integers to be converted.
415415
* @returns {Tensor} The int64 Tensor with the converted values.
416416
* @throws {Error} If the input array is empty or the input is a batched Tensor and not all sequences have the same length.
417417
* @private
@@ -1334,35 +1334,37 @@ export class PreTrainedModel extends Callable {
13341334
let { decoder_input_ids, ...model_inputs } = model_kwargs;
13351335

13361336
// Prepare input ids if the user has not defined `decoder_input_ids` manually.
1337-
if (!decoder_input_ids) {
1338-
decoder_start_token_id ??= bos_token_id;
1339-
1340-
if (this.config.model_type === 'musicgen') {
1341-
// Custom logic (TODO: move to Musicgen class)
1342-
decoder_input_ids = Array.from({
1343-
length: batch_size * this.config.decoder.num_codebooks
1344-
}, () => [decoder_start_token_id]);
1345-
1346-
} else if (Array.isArray(decoder_start_token_id)) {
1347-
if (decoder_start_token_id.length !== batch_size) {
1348-
throw new Error(
1349-
`\`decoder_start_token_id\` expcted to have length ${batch_size} but got ${decoder_start_token_id.length}`
1350-
)
1337+
if (!(decoder_input_ids instanceof Tensor)) {
1338+
if (!decoder_input_ids) {
1339+
decoder_start_token_id ??= bos_token_id;
1340+
1341+
if (this.config.model_type === 'musicgen') {
1342+
// Custom logic (TODO: move to Musicgen class)
1343+
decoder_input_ids = Array.from({
1344+
length: batch_size * this.config.decoder.num_codebooks
1345+
}, () => [decoder_start_token_id]);
1346+
1347+
} else if (Array.isArray(decoder_start_token_id)) {
1348+
if (decoder_start_token_id.length !== batch_size) {
1349+
throw new Error(
1350+
`\`decoder_start_token_id\` expcted to have length ${batch_size} but got ${decoder_start_token_id.length}`
1351+
)
1352+
}
1353+
decoder_input_ids = decoder_start_token_id;
1354+
} else {
1355+
decoder_input_ids = Array.from({
1356+
length: batch_size,
1357+
}, () => [decoder_start_token_id]);
13511358
}
1352-
decoder_input_ids = decoder_start_token_id;
1353-
} else {
1359+
} else if (!Array.isArray(decoder_input_ids[0])) {
1360+
// Correct batch size
13541361
decoder_input_ids = Array.from({
13551362
length: batch_size,
1356-
}, () => [decoder_start_token_id]);
1363+
}, () => decoder_input_ids);
13571364
}
1358-
} else if (!Array.isArray(decoder_input_ids[0])) {
1359-
// Correct batch size
1360-
decoder_input_ids = Array.from({
1361-
length: batch_size,
1362-
}, () => decoder_input_ids);
1365+
decoder_input_ids = toI64Tensor(decoder_input_ids);
13631366
}
13641367

1365-
decoder_input_ids = toI64Tensor(decoder_input_ids);
13661368
model_kwargs['decoder_attention_mask'] = ones_like(decoder_input_ids);
13671369

13681370
return { input_ids: decoder_input_ids, model_inputs };
@@ -3185,8 +3187,11 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
31853187
export class VisionEncoderDecoderModel extends PreTrainedModel {
31863188
main_input_name = 'pixel_values';
31873189
forward_params = [
3190+
// Encoder inputs
31883191
'pixel_values',
3189-
'input_ids',
3192+
3193+
// Decoder inputs
3194+
'decoder_input_ids',
31903195
'encoder_hidden_states',
31913196
'past_key_values',
31923197
];

src/pipelines.js

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2566,7 +2566,6 @@ export class DocumentQuestionAnsweringPipeline extends (/** @type {new (options:
25662566

25672567
/** @type {DocumentQuestionAnsweringPipelineCallback} */
25682568
async _call(image, question, generate_kwargs = {}) {
2569-
throw new Error('This pipeline is not yet supported in Transformers.js v3.'); // TODO: Remove when implemented
25702569

25712570
// NOTE: For now, we only support a batch size of 1
25722571

src/processors.js

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,6 +1209,7 @@ export class DonutFeatureExtractor extends ImageFeatureExtractor {
12091209
});
12101210
}
12111211
}
1212+
export class DonutImageProcessor extends DonutFeatureExtractor { } // NOTE extends DonutFeatureExtractor
12121213
export class NougatImageProcessor extends DonutFeatureExtractor { } // NOTE extends DonutFeatureExtractor
12131214

12141215
/**
@@ -2569,6 +2570,7 @@ export class AutoProcessor {
25692570
MaskFormerFeatureExtractor,
25702571
YolosFeatureExtractor,
25712572
DonutFeatureExtractor,
2573+
DonutImageProcessor,
25722574
NougatImageProcessor,
25732575
EfficientNetImageProcessor,
25742576

0 commit comments

Comments
 (0)