Skip to content

Commit a602128

Browse files
committed
Add model loading code for LoRA
1 parent 78b428f commit a602128

File tree

3 files changed

+26
-9
lines changed

3 files changed

+26
-9
lines changed

.vscode/launch.json

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,12 @@
9090
// "request": "launch",
9191
// "program": "docs/LLaVA_OneVision_Tutorials.py",
9292
// "console": "integratedTerminal",
93-
// "env":{"CUDA_VISIBLE_DEVICES":"0",
94-
// "LD_PRELOAD": "/usr/lib/x86_64-linux-gnu/libffi.so.7"},
93+
// "env":{
94+
// "CUDA_VISIBLE_DEVICES":"0",
95+
// // "HF_HOME": "/mnt/SV_storage/VFM/huggingface",
96+
// // "LD_PRELOAD": "/usr/lib/x86_64-linux-gnu/libffi.so.7"
97+
// },
9598
// "justMyCode": false,
96-
// // "args": [
97-
// // "--run_dir_name", "test",
98-
// // // "--use_big_decoder"
99-
// // ]
10099
// }
101100
// ]
102101
// }

docs/LLaVA_OneVision_Tutorials.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,15 @@
7373

7474
warnings.filterwarnings("ignore")
7575
# Load the OneVision model
76-
pretrained = "lmms-lab/llava-onevision-qwen2-7b-ov"
77-
model_name = "llava_qwen"
76+
# pretrained = "/mnt/SV_storage/VFM/huggingface/hub/models--lmms-lab--llava-onevision-qwen2-0.5b-ov/snapshots/381d9947148efb1e58a577f451c05705ceec666e"
77+
# pretrained = "/mnt/SV_storage/VFM/LLaVA-NeXT/experiments/EK100_quick_config"
78+
# model_base = None
79+
pretrained = "/mnt/SV_storage/VFM/LLaVA-NeXT/experiments/EK100_lora_quick_check"
80+
model_base = "/mnt/SV_storage/VFM/huggingface/hub/models--lmms-lab--llava-onevision-qwen2-0.5b-ov/snapshots/381d9947148efb1e58a577f451c05705ceec666e"
81+
model_name = "lora_llava_qwen"
7882
device = "cuda"
7983
device_map = "auto"
80-
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map, attn_implementation="sdpa")
84+
tokenizer, model, image_processor, max_length = load_pretrained_model(pretrained, model_base, model_name, device_map=device_map, attn_implementation="sdpa")
8185

8286
model.eval()
8387

llava/model/builder.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,20 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
7171
lora_cfg_pretrained = LlavaMistralConfig.from_pretrained(model_path)
7272
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
7373
model = LlavaMistralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, attn_implementation=attn_implementation, **kwargs)
74+
75+
elif "qwen" in model_name.lower() or "quyen" in model_name.lower():
76+
77+
tokenizer = AutoTokenizer.from_pretrained(model_base)
78+
if "moe" in model_name.lower() or "A14B" in model_name.lower():
79+
from llava.model.language_model.llava_qwen_moe import LlavaQwenMoeConfig
80+
lora_cfg_pretrained = LlavaQwenMoeConfig.from_pretrained(model_path)
81+
model = LlavaQwenMoeForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=lora_cfg_pretrained, **kwargs)
82+
else:
83+
from llava.model.language_model.llava_qwen import LlavaQwenConfig
84+
lora_cfg_pretrained = LlavaQwenConfig.from_pretrained(model_path)
85+
model = LlavaQwenForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, attn_implementation=attn_implementation, config=lora_cfg_pretrained, **kwargs)
86+
87+
7488
elif "gemma" in model_name.lower():
7589
from llava.model.language_model.llava_gemma import LlavaGemmaConfig
7690

0 commit comments

Comments (0)