@@ -411,7 +411,7 @@ function replaceTensors(obj) {
411411
412412/**
413413 * Converts an array or Tensor of integers to an int64 Tensor.
414- * @param {Array |Tensor } items The input integers to be converted.
414+ * @param {any[] |Tensor } items The input integers to be converted.
415415 * @returns {Tensor } The int64 Tensor with the converted values.
416416 * @throws {Error } If the input array is empty or the input is a batched Tensor and not all sequences have the same length.
417417 * @private
@@ -1334,35 +1334,37 @@ export class PreTrainedModel extends Callable {
13341334 let { decoder_input_ids, ...model_inputs } = model_kwargs ;
13351335
13361336 // Prepare input ids if the user has not defined `decoder_input_ids` manually.
1337- if ( ! decoder_input_ids ) {
1338- decoder_start_token_id ??= bos_token_id ;
1339-
1340- if ( this . config . model_type === 'musicgen' ) {
1341- // Custom logic (TODO: move to Musicgen class)
1342- decoder_input_ids = Array . from ( {
1343- length : batch_size * this . config . decoder . num_codebooks
1344- } , ( ) => [ decoder_start_token_id ] ) ;
1345-
1346- } else if ( Array . isArray ( decoder_start_token_id ) ) {
1347- if ( decoder_start_token_id . length !== batch_size ) {
1348- throw new Error (
1349- `\`decoder_start_token_id\` expcted to have length ${ batch_size } but got ${ decoder_start_token_id . length } `
1350- )
1337+ if ( ! ( decoder_input_ids instanceof Tensor ) ) {
1338+ if ( ! decoder_input_ids ) {
1339+ decoder_start_token_id ??= bos_token_id ;
1340+
1341+ if ( this . config . model_type === 'musicgen' ) {
1342+ // Custom logic (TODO: move to Musicgen class)
1343+ decoder_input_ids = Array . from ( {
1344+ length : batch_size * this . config . decoder . num_codebooks
1345+ } , ( ) => [ decoder_start_token_id ] ) ;
1346+
1347+ } else if ( Array . isArray ( decoder_start_token_id ) ) {
1348+ if ( decoder_start_token_id . length !== batch_size ) {
1349+ throw new Error (
1350+ `\`decoder_start_token_id\` expcted to have length ${ batch_size } but got ${ decoder_start_token_id . length } `
1351+ )
1352+ }
1353+ decoder_input_ids = decoder_start_token_id ;
1354+ } else {
1355+ decoder_input_ids = Array . from ( {
1356+ length : batch_size ,
1357+ } , ( ) => [ decoder_start_token_id ] ) ;
13511358 }
1352- decoder_input_ids = decoder_start_token_id ;
1353- } else {
1359+ } else if ( ! Array . isArray ( decoder_input_ids [ 0 ] ) ) {
1360+ // Correct batch size
13541361 decoder_input_ids = Array . from ( {
13551362 length : batch_size ,
1356- } , ( ) => [ decoder_start_token_id ] ) ;
1363+ } , ( ) => decoder_input_ids ) ;
13571364 }
1358- } else if ( ! Array . isArray ( decoder_input_ids [ 0 ] ) ) {
1359- // Correct batch size
1360- decoder_input_ids = Array . from ( {
1361- length : batch_size ,
1362- } , ( ) => decoder_input_ids ) ;
1365+ decoder_input_ids = toI64Tensor ( decoder_input_ids ) ;
13631366 }
13641367
1365- decoder_input_ids = toI64Tensor ( decoder_input_ids ) ;
13661368 model_kwargs [ 'decoder_attention_mask' ] = ones_like ( decoder_input_ids ) ;
13671369
13681370 return { input_ids : decoder_input_ids , model_inputs } ;
@@ -3185,8 +3187,11 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel {
31853187export class VisionEncoderDecoderModel extends PreTrainedModel {
31863188 main_input_name = 'pixel_values' ;
31873189 forward_params = [
3190+ // Encoder inputs
31883191 'pixel_values' ,
3189- 'input_ids' ,
3192+
3193+ // Decoder inpputs
3194+ 'decoder_input_ids' ,
31903195 'encoder_hidden_states' ,
31913196 'past_key_values' ,
31923197 ] ;
0 commit comments