@@ -61,7 +61,6 @@ import {
 } from './utils/generic.js';
 
 import {
-    isIntegralNumber,
     mergeArrays,
     pick,
 } from './utils/core.js';
@@ -99,6 +98,7 @@ import {
 
 import {
     cat,
+    full,
     full_like,
     mean,
     ones,
@@ -108,6 +108,7 @@ import {
     Tensor,
     zeros_like,
 } from './utils/tensor.js';
+import { RawImage } from './utils/image.js';
 
 import { dynamic_time_warping, medianFilter } from './utils/maths.js';
 import { EosTokenCriteria, MaxLengthCriteria, StoppingCriteriaList } from './generation/stopping_criteria.js';
@@ -128,6 +129,7 @@ const MODEL_TYPES = {
     MaskGeneration: 5,
     ImageTextToText: 6,
     Musicgen: 7,
+    MultiModality: 8,
 }
 //////////////////////////////////////////////////
 
@@ -386,7 +388,7 @@ async function sessionRun(session, inputs) {
     } catch (e) {
         // This usually occurs when the inputs are of the wrong type.
         console.error(`An error occurred during model execution: "${e}".`);
-        console.error('Inputs given to model:', checkedInputs);
+        console.error('Inputs given to model:', checkedInputs)
         throw e;
     }
 }
@@ -716,6 +718,63 @@ function image_text_to_text_prepare_inputs_for_generation(self, ...args) {
     }
 }
 
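+// Prepares inputs for multimodal (text + image) generation. Inferred from the
+// body below: it doubles the batch for classifier-free guidance, substitutes a
+// zero-sized pixel_values tensor once image features are cached, and rebuilds
+// the image/text masks for single-token decoding steps.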
+function multimodality_prepare_inputs_for_generation(self, input_ids, model_inputs, generation_config) {
+    const has_past_key_values = !!model_inputs.past_key_values;
+
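+    // Classifier-free guidance: run conditional and unconditional branches in
+    // one batch by doubling input_ids (the unconditional half is padding with
+    // its attention masked out).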
+    if (generation_config.guidance_scale !== null && generation_config.guidance_scale > 1) {
+        if (has_past_key_values) {
+            model_inputs.input_ids = cat([
+                model_inputs.input_ids,
+                model_inputs.input_ids,
+            ], 0);
+            // NOTE: attention_mask handled in generation
+        } else {
+            model_inputs.input_ids = cat([
+                model_inputs.input_ids,
+                full_like(model_inputs.input_ids, BigInt(generation_config.pad_token_id)),
+            ], 0);
+            model_inputs.attention_mask = cat([
+                model_inputs.attention_mask,
+                full_like(model_inputs.attention_mask, 0n),
+            ], 0);
+        }
+    }
+
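+    // Once generation is under way (or no image was provided), a zero-sized
+    // dummy pixel_values tensor satisfies the graph inputs.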
+    if (has_past_key_values || !model_inputs.pixel_values) {
+        model_inputs.pixel_values = full([0, 0, 3, 384, 384], 1.0);
+    }
+
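+    // With past_key_values present, each step processes exactly one new text
+    // token and no image tokens, so the masks are rebuilt accordingly.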
+    if (has_past_key_values) {
+        const num_img_tokens = 0;
+        const num_text_tokens = 1;
+        const has_image = num_img_tokens > 0 ? 1 : 0;
+
+        const batch_size = 1;
+        model_inputs.images_seq_mask = new Tensor(
+            'bool',
+            new Array(num_img_tokens + num_text_tokens).fill(true).fill(false, 0, num_text_tokens),
+            [batch_size, num_img_tokens + num_text_tokens],
+        );
+        model_inputs.images_emb_mask = new Tensor(
+            'bool',
+            new Array(num_img_tokens).fill(!!has_image),
+            [batch_size, 1, num_img_tokens],
+        );
+    }
+    return model_inputs;
+}
+
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
@@ -769,6 +817,11 @@ export class PreTrainedModel extends Callable {
                 this._prepare_inputs_for_generation = image_text_to_text_prepare_inputs_for_generation;
                 break;
 
+            case MODEL_TYPES.MultiModality:
+                this.can_generate = true;
+                this._prepare_inputs_for_generation = multimodality_prepare_inputs_for_generation;
+                break;
+
             default:
                 // should be MODEL_TYPES.EncoderOnly
                 this._forward = encoderForward;
@@ -912,6 +965,25 @@ export class PreTrainedModel extends Callable {
             }, options),
         ]);
 
+        } else if (modelType === MODEL_TYPES.MultiModality) {
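+            // Judging by the keys below, the export is split into six ONNX
+            // sub-graphs: an input-embedding preparer, the language model,
+            // text (lm_head) and image (gen_head) generation heads, an
+            // image-token embedder, and an image decoder.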
+            info = await Promise.all([
+                constructSessions(pretrained_model_name_or_path, {
+                    prepare_inputs_embeds: 'prepare_inputs_embeds',
+                    model: 'language_model',
+                    lm_head: 'lm_head',
+                    gen_head: 'gen_head',
+                    gen_img_embeds: 'gen_img_embeds',
+                    image_decode: 'image_decode',
+                }, options),
+                getOptionalConfigs(pretrained_model_name_or_path, {
+                    generation_config: 'generation_config.json',
+                }, options),
+            ]);
+
         } else { // should be MODEL_TYPES.EncoderOnly
             if (modelType !== MODEL_TYPES.EncoderOnly) {
                 console.warn(`Model type for '${modelName ?? config?.model_type}' not found, assuming encoder-only architecture. Please report this at ${GITHUB_ISSUE_URL}.`)
@@ -1658,7 +1726,10 @@ export class PreTrainedModel extends Callable {
         const dtype = session?.config?.kv_cache_dtype ?? 'float32';
         const empty = (dtype === 'float16') ? new Uint16Array() : [];
 
-        const shapes = getKeyValueShapes(this.config);
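+        // Infer the batch size from the actual inputs instead of assuming 1,
+        // since classifier-free guidance can double the batch.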
+        const batch_size = decoderFeeds[this.main_input_name].dims[0];
+        const shapes = getKeyValueShapes(this.config, { batch_size });
 
         for (const name in shapes) {
             decoderFeeds[name] = new Tensor(dtype, empty, shapes[name]);
@@ -5954,6 +6023,133 @@ export class DecisionTransformerModel extends DecisionTransformerPreTrainedModel
 
 //////////////////////////////////////////////////
 
+export class MultiModalityPreTrainedModel extends PreTrainedModel { }
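+/**
+ * A multimodal model that can generate both text and images, switching between
+ * a text head and an image head during generation.
+ *
+ * A hypothetical text-to-image sketch (the preprocessing step and the token
+ * count are assumptions for illustration; only `generate_images` below is part
+ * of this diff):
+ *
+ * ```js
+ * const model = await MultiModalityCausalLM.from_pretrained(model_id);
+ * const inputs = await processor(prompt); // assumed preprocessing step
+ * const images = await model.generate_images({
+ *     ...inputs,
+ *     max_new_tokens: 576, // assumed number of image tokens
+ *     do_sample: true,
+ * });
+ * await images[0].save('output.png');
+ * ```
+ */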
+export class MultiModalityCausalLM extends MultiModalityPreTrainedModel {
+    forward_params = [
+        // prepare_inputs_embeds
+        'input_ids',
+        'pixel_values',
+        'images_seq_mask',
+        'images_emb_mask',
+
+        // language_model
+        'attention_mask',
+        'position_ids',
+        'past_key_values',
+    ];
+
+    constructor(...args) {
+        super(...args);
+
+        // State-based approach to switch out which heads to use during generation
+        this._generation_mode = 'text';
+    }
+
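+    // Each forward pass stitches together the ONNX sub-graphs: (1) compute
+    // input embeddings, either from text+image inputs or from previously
+    // generated image tokens, (2) run the language model, and (3) apply the
+    // mode-specific head (lm_head or gen_head) to produce logits.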
+    async forward(model_inputs) {
+        const mode = this._generation_mode ?? 'text';
+
+        // TODO support re-using PKVs for input_ids.dims[1] !== 1
+        // if (model_inputs.past_key_values) {
+        //     // && model_inputs.input_ids.dims[1] === 1
+        // }
+
+        let output_1;
+        if (mode === 'text' || !model_inputs.past_key_values) {
+            const session = this.sessions['prepare_inputs_embeds'];
+            const prep_inputs = pick(model_inputs, session.inputNames);
+            output_1 = await sessionRun(session, prep_inputs);
+        } else {
+            const session = this.sessions['gen_img_embeds'];
+            const prep_inputs = pick({
+                image_ids: model_inputs.input_ids,
+            }, session.inputNames);
+            output_1 = await sessionRun(session, prep_inputs);
+        }
+
+        const input_2 = { ...model_inputs, ...output_1 };
+        const output_2 = await decoderForward(this, input_2);
+
+        const head_name = mode === 'text' ? 'lm_head' : 'gen_head';
+        const head = this.sessions[head_name];
+        if (!head) {
+            throw new Error(`Unable to find "${head_name}" generation head`);
+        }
+
+        const output_3 = await sessionRun(head, pick(output_2, head.inputNames));
+
+        return {
+            ...output_1,
+            ...output_2,
+            ...output_3,
+        };
+    }
+
+    /**
+     * @param {import('./generation/parameters.js').GenerationFunctionParameters} options
+     */
+    async generate(options) {
+        this._generation_mode = 'text';
+        return super.generate(options);
+    }
+
+    /**
+     * @param {import('./generation/parameters.js').GenerationFunctionParameters} options
+     * @returns {Promise<RawImage[]>} The generated images.
+     */
+    async generate_images(options) {
+        this._generation_mode = 'image';
+
+        const start_num_tokens = (options.inputs ?? options[this.main_input_name]).dims[1];
+        const all_tokens = await super.generate(options);
+
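+        // Keep only the newly generated image tokens, dropping the prompt.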
+        const generated_tokens = (/** @type {Tensor} */(all_tokens)).slice(null, [start_num_tokens, null]);
+
+        const image_decode = this.sessions['image_decode'];
+        const { decoded_image } = await sessionRun(image_decode, {
+            generated_tokens,
+        });
+
+        // Equivalent to `np.clip((dec + 1) / 2 * 255, 0, 255)`
+        const clamped = decoded_image
+            .add_(1)
+            .mul_(255 / 2)
+            .clamp_(0, 255)
+            .to('uint8');
+
+        // Return as a list of images
+        const images = [];
+        for (const tensor of clamped) {
+            const img = RawImage.fromTensor(tensor);
+            images.push(img);
+        }
+        return images;
+    }
+}
+
 //////////////////////////////////////////////////
 // AutoModels, used to simplify construction of PreTrainedModels
 // (uses config to instantiate correct class)
@@ -6232,6 +6406,11 @@ const MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = new Map([
     ['stablelm', ['StableLmForCausalLM', StableLmForCausalLM]],
 ]);
 
+const MODEL_FOR_MULTIMODALITY_MAPPING_NAMES = new Map([
+    ['multi_modality', ['MultiModalityCausalLM', MultiModalityCausalLM]],
+]);
+
+
 const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([
     ['bert', ['BertForMaskedLM', BertForMaskedLM]],
     ['roformer', ['RoFormerForMaskedLM', RoFormerForMaskedLM]],
@@ -6404,6 +6583,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
     [MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
     [MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
     [MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_TYPES.DecoderOnly],
+    [MODEL_FOR_MULTIMODALITY_MAPPING_NAMES, MODEL_TYPES.MultiModality],
     [MODEL_FOR_MASKED_LM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Vision2Seq],