@@ -5,7 +5,7 @@
 import importlib
 from pathlib import Path

-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig
 import torch
 import numpy as np

@@ -116,11 +116,11 @@ def debug_hook(name):
     def fn(_m, input, output):
         if isinstance(input, torch.Tensor):
             summarize(input, name + "_in")
-        elif isinstance(input, (tuple, list)) and isinstance(input[0], torch.Tensor):
+        elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor):
             summarize(input[0], name + "_in")
         if isinstance(output, torch.Tensor):
             summarize(output, name + "_out")
-        elif isinstance(output, (tuple, list)) and isinstance(output[0], torch.Tensor):
+        elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor):
             summarize(output[0], name + "_out")

     return fn
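The `len(...) > 0` guards fix a real crash: PyTorch forward hooks receive `input` as a tuple of the module's positional arguments, and that tuple is empty when a module is called with keyword arguments only, so the old `input[0]` could raise `IndexError`. A minimal repro on a toy module (not from this script):

```python
import torch

class KwOnly(torch.nn.Module):
    def forward(self, *, x):  # accepts keyword arguments only
        return x * 2

m = KwOnly()
# Forward hooks see positional args only (unless registered with with_kwargs=True),
# so `inp` below is the empty tuple () and inp[0] would raise IndexError.
m.register_forward_hook(lambda _m, inp, out: print("hook input tuple:", inp))
m(x=torch.ones(2))  # prints: hook input tuple: ()
```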
@@ -130,6 +130,7 @@ def fn(_m, input, output):

 parser = argparse.ArgumentParser(description="Process model with specified path")
 parser.add_argument("--model-path", "-m", help="Path to the model")
+parser.add_argument("--prompt-file", "-f", help="Optional prompt file", required=False)
 args = parser.parse_args()

 model_path = os.environ.get("MODEL_PATH", args.model_path)
@@ -142,8 +143,13 @@ def fn(_m, input, output):
 print("Loading model and tokenizer using AutoTokenizer:", model_path)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+multimodal = False
+full_config = config

 print("Model type: ", config.model_type)
+if "vocab_size" not in config and "text_config" in config:
+    config = config.text_config
+    multimodal = True
 print("Vocab size: ", config.vocab_size)
 print("Hidden size: ", config.hidden_size)
 print("Number of layers: ", config.num_hidden_layers)
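This new branch makes the config prints work for multimodal checkpoints, whose language-model fields live one level down in `text_config` instead of at the top of the config. An illustrative sketch with plain dicts (shapes and values are assumptions, not taken from this diff):

```python
# Text-only configs keep LM fields at the top level; image-text configs nest
# them under "text_config" (numbers made up for illustration).
text_only = {"model_type": "llama", "vocab_size": 32000, "hidden_size": 4096}
multimodal = {
    "model_type": "llava",
    "vision_config": {"hidden_size": 1024},
    "text_config": {"vocab_size": 32000, "hidden_size": 4096},
}

cfg = multimodal
if "vocab_size" not in cfg and "text_config" in cfg:  # same test as the hunk above
    cfg = cfg["text_config"]
print(cfg["vocab_size"])  # 32000
```

Note that the script keeps the untouched object around as `full_config`: the narrowed `config` only feeds the prints, while `full_config` is what the multimodal model class gets loaded with below.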
@@ -169,9 +175,14 @@ def fn(_m, input, output):
     print(f"Failed to import or load model: {e}")
     exit(1)
 else:
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
-    )
+    if multimodal:
+        model = AutoModelForImageTextToText.from_pretrained(
+            model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=full_config
+        )
+    else:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path, device_map="auto", offload_folder="offload", trust_remote_code=True, config=config
+        )

 for name, module in model.named_modules():
     if len(list(module.children())) == 0:  # only leaf modules
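The hunk ends just inside the hook-registration loop, whose body is elided from this diff; presumably it attaches `debug_hook(name)` to each leaf module. A toy sketch of that assumed wiring:

```python
import torch

toy = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
for name, module in toy.named_modules():
    if len(list(module.children())) == 0:  # leaf modules only, as above
        # n=name pins the loop variable; without it every hook would print
        # the last name seen.
        module.register_forward_hook(
            lambda _m, inp, out, n=name: print(n, "->", tuple(out.shape))
        )
toy(torch.randn(1, 4))  # prints: 0 -> (1, 4) then 1 -> (1, 4)
```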
@@ -185,7 +196,10 @@ def fn(_m, input, output):
 print(f"Model class: {model.__class__.__name__}")

 device = next(model.parameters()).device
-if os.getenv("MODEL_TESTING_PROMPT"):
+if args.prompt_file:
+    with open(args.prompt_file, encoding='utf-8') as f:
+        prompt = f.read()
+elif os.getenv("MODEL_TESTING_PROMPT"):
     prompt = os.getenv("MODEL_TESTING_PROMPT")
 else:
     prompt = "Hello, my name is"
@@ -195,9 +209,18 @@ def fn(_m, input, output):
 print(f"Input text: {repr(prompt)}")
 print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}")

+batch_size = 512
+
 with torch.no_grad():
-    outputs = model(input_ids.to(model.device))
-    logits = outputs.logits
+    past = None
+    outputs = None
+    for i in range(0, input_ids.size(1), batch_size):
+        print(f"Processing chunk with tokens {i} to {i + batch_size}")
+        chunk = input_ids[:, i:i + batch_size]
+        outputs = model(chunk.to(model.device), past_key_values=past, use_cache=True)
+        past = outputs.past_key_values
+
+    logits = outputs.logits  # type: ignore

     # Extract logits for the last token (next token prediction)
     last_logits = logits[0, -1, :].float().cpu().numpy()
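The rewritten forward pass is a chunked prefill: long prompts are fed through the model 512 tokens at a time, threading the KV cache (`past_key_values`) between chunks, so the final logits match a single full-length pass while peak activation memory stays bounded by the chunk size. One small wrinkle: the progress line prints `i + batch_size` even when the final chunk is shorter, though the slice itself is clipped correctly. A sanity-check sketch against a tiny public checkpoint (model choice and sizes are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "sshleifer/tiny-gpt2"  # tiny test checkpoint, stand-in for a real model
tok = AutoTokenizer.from_pretrained(name)
lm = AutoModelForCausalLM.from_pretrained(name)
ids = tok("hello world " * 20, return_tensors="pt").input_ids

with torch.no_grad():
    full = lm(ids).logits[0, -1]        # one full-length pass
    past, out = None, None
    for i in range(0, ids.size(1), 8):  # chunk size 8 instead of 512
        out = lm(ids[:, i:i + 8], past_key_values=past, use_cache=True)
        past = out.past_key_values
    chunked = out.logits[0, -1]         # final logits of the last chunk

print(torch.allclose(full, chunked, atol=1e-4))  # True, up to float noise
```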