
Commit e8b6f9c

Merge pull request #52 from Lornatang/merge_model
refactor(merge_model): Enhanced file implementation robustness
2 parents 42c6234 + ffe4a27 commit e8b6f9c

File tree

1 file changed (+9, -21 lines)

ds/merge_model.py

Lines changed: 9 additions & 21 deletions
@@ -37,12 +37,12 @@ def create_test_image():

 def load_empty_model(llm_path):
     print("Loading tokenizer and processor from Qwen2.5-VL and empty model...")
-    tokenizer = Qwen2Tokenizer.from_pretrained('Qwen/Qwen2.5-VL-7B-Instruct', trust_remote_code=True, device_map={"": f"cuda:{CUDA_DEVICE}"})
-    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+    tokenizer = Qwen2Tokenizer.from_pretrained('Qwen/Qwen2.5-VL-7B-Instruct', trust_remote_code=True, device_map={"": f"cuda:{CUDA_DEVICE}"}, use_fast=True)
+    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", use_fast=True)
     processor.image_processor.temporal_patch_size = 1
     processor.image_processor.max_pixels = 1600*1600
     llava_ov_config = Llavaonevision1_5Config()
-    llm_config = AutoConfig.from_pretrained(llm_path, trust_remote_code=True)
+    llm_config = AutoConfig.from_pretrained(llm_path, trust_remote_code=True, use_fast=True)
     llava_ov_config.text_config.update(llm_config.to_dict())
     llava_ov_config.vision_config.text_hidden_size = llava_ov_config.text_config.hidden_size
     model = LLaVAOneVision1_5_ForConditionalGeneration(llava_ov_config)

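As a quick check that the fast preprocessing path requested above is actually picked up, the loaded processor class can be inspected directly. A minimal standalone sketch, assuming a transformers build with Qwen2.5-VL support and access to the Hub checkpoint; it is not part of merge_model.py:

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", use_fast=True)

# Fast image-processor implementations carry a "Fast" suffix in their class name
# (e.g. Qwen2VLImageProcessorFast), so printing the type shows which path was loaded.
print(type(processor.image_processor).__name__)
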
@@ -199,7 +199,6 @@ def load_llm_weights(model, llm_path, cur_len):
     else:
         cache_path = snapshot_download(llm_path, allow_patterns="*.safetensors")

-
     llm_weights = {}
     if os.path.isdir(cache_path):
         for filename in os.listdir(cache_path):

@@ -215,7 +214,6 @@
         llm_weights = llm_weights["state_dict"]

     loaded_keys = 0
-    total_keys = 0

     ADAPTER_KEYS_TO_MODIFY_MAPPING = {
         "model.": "model.language_model.",

@@ -237,9 +235,7 @@ def convert_state_dict(state_dict):
     llm_weights['lm_head.weight'] = llm_weights['model.language_model.embed_tokens.weight']
     llm_keys = len(set(llm_weights.keys()))

-
     model_state_dict = model.state_dict()
-    total_keys = len(model_state_dict.keys())
     for llm_key in llm_weights:
         if llm_key not in model_state_dict:
             logger.warning(f"LLM key {llm_key} not found in model, skipping...")

@@ -264,9 +260,8 @@ def validate_vit_consistency(model, vit_path, img_path):
     sample_image = Image.open(BytesIO(response.content)).convert("RGB")
     sample_image = sample_image.resize((560, 560))

-
-    rice_model = MLCDVisionModel.from_pretrained(vit_path, device_map={"": f"cuda:{CUDA_DEVICE}"}, torch_dtype=torch.float32)
-    processor = CLIPImageProcessor.from_pretrained(vit_path, device_map={"": f"cuda:{CUDA_DEVICE}"}, torch_dtype=torch.float32)
+    rice_model = MLCDVisionModel.from_pretrained(vit_path, device_map={"": f"cuda:{CUDA_DEVICE}"}, dtype=torch.float32)
+    processor = CLIPImageProcessor.from_pretrained(vit_path, device_map={"": f"cuda:{CUDA_DEVICE}"}, dtype=torch.float32, use_fast=True)
     rice_inputs = processor.preprocess(images=sample_image, return_tensors="pt").to(dtype=model.dtype, device=rice_model.device)

     rice_model = rice_model.eval()

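The dtype= keyword used in these two calls is the newer spelling accepted by recent transformers releases, with torch_dtype= kept as the older alias. A minimal sketch of the vision-tower loading call in isolation; the checkpoint path and device index below are placeholders, not values taken from this repository:

import torch
from transformers import MLCDVisionModel

vit_path = "/path/to/rice-vit-checkpoint"  # placeholder path

# Load the vision tower in float32 on a single GPU; dtype= plays the role that
# torch_dtype= played in older transformers versions.
rice_model = MLCDVisionModel.from_pretrained(
    vit_path,
    device_map={"": "cuda:0"},  # stands in for the script's CUDA_DEVICE constant
    dtype=torch.float32,
)
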
@@ -285,16 +280,13 @@ def spatial_reorder(tensor):
         output = spatial_reorder(output)
         reord_output_list.append(output)
     rice_vit_features = reord_output_list[-1]
-
-

     image_grid_thw = torch.tensor([[1, 40, 40]], device=model.device, dtype=torch.long)
     image_processor = Qwen2VLImageProcessor()
     image_processor.temporal_patch_size=1
     processed_image = image_processor(sample_image, return_tensors="pt")
     with torch.no_grad():
         merged_output = model.visual(processed_image['pixel_values'].to(device=model.device,dtype=model.dtype), grid_thw=image_grid_thw, is_verifying=True)
-

     if isinstance(merged_output, torch.Tensor) and isinstance(rice_vit_features, torch.Tensor):
         diff = (merged_output - rice_vit_features).abs().mean().item()

@@ -316,9 +308,8 @@ def validate_llm_consistency(model, llm_path, sample_text):
     print("Verifying consistency of LLM component...")

     # Load original LLM model
-
     original_llm = AutoModelForCausalLM.from_pretrained(llm_path).to(dtype=model.dtype, device=model.device)
-    tokenizer = AutoTokenizer.from_pretrained(llm_path)
+    tokenizer = AutoTokenizer.from_pretrained(llm_path, use_fast=True)

     # Prepare sample text
     inputs = tokenizer(sample_text, return_tensors="pt").to(model.device)

@@ -365,11 +356,8 @@ def main(args):
     adapter_path = args.adapter_path
     llm_path = args.llm_path
     output_path = args.output_path
-
-    # test data
-    img_path = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
-    sample_text = "Hello, my dog is cute"
-
+    img_path = args.img_path
+    sample_text = args.sample_text

     # 1. load empty model
     model, processor, tokenizer = load_empty_model(llm_path)

@@ -408,4 +396,4 @@ def main(args):
     parser.add_argument("--img_path", type=str, default="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", help="Path to the image file")
     parser.add_argument("--sample_text", type=str, default="Hello, my dog is cute", help="Sample text for LLM consistency check")
     args = parser.parse_args()
-    main(args)
+    main(args)

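Because the defaults added to the parser reproduce the previously hard-coded test inputs, running the script without --img_path or --sample_text keeps the old behaviour, while either flag can now override the verification inputs from the command line. A standalone sketch of how the new flags resolve, for illustration only:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--img_path", type=str,
                    default="https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
                    help="Path to the image file")
parser.add_argument("--sample_text", type=str, default="Hello, my dog is cute",
                    help="Sample text for LLM consistency check")

# An empty argument list falls back to the defaults, i.e. the same demo image URL
# and sample sentence that main() used to hard-code.
args = parser.parse_args([])
print(args.img_path)
print(args.sample_text)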
