|
9 | 9 | Llava15ChatHandler, |
10 | 10 | Llava16ChatHandler, |
11 | 11 | ) |
| 12 | +from huggingface_hub import hf_hub_download |
| 13 | +from tqdm import tqdm |
12 | 14 |
|
13 | 15 | from llama_assistant import config |
14 | 16 | from llama_assistant.agent import RAGAgent |
@@ -86,54 +88,88 @@ def load_agent( |
86 | 88 | if model.is_online(): |
87 | 89 | if model.model_type == "text" or model.model_type == "text-reasoning": |
88 | 90 | print("load online model") |
89 | | - loaded_model = Llama.from_pretrained( |
| 91 | + # Download with progress bar |
| 92 | + model_path = hf_hub_download( |
90 | 93 | repo_id=model.repo_id, |
91 | 94 | filename=model.filename, |
| 95 | + # NOTE(review): removed resume_download=True (deprecated no-op in huggingface_hub>=0.22; resume is default) |
| 96 | + # and tqdm_class=tqdm — hf_hub_download() has no such parameter (snapshot_download does) and would raise TypeError. |
| 97 | + ) |
| 98 | + loaded_model = Llama( |
| 99 | + model_path=model_path, |
92 | 100 | n_gpu_layers=-1, |
93 | 101 | n_ctx=generation_setting["context_len"], |
94 | 102 | ) |
95 | 103 | elif model.model_type == "image": |
96 | 104 | if "moondream2" in model.model_id: |
| 105 | + print("Downloading vision model projector...") |
97 | 106 | chat_handler = MoondreamChatHandler.from_pretrained( |
98 | 107 | repo_id="vikhyatk/moondream2", |
99 | 108 | filename="*mmproj*", |
100 | 109 | ) |
101 | | - loaded_model = Llama.from_pretrained( |
| 110 | + print("Downloading main model...") |
| 111 | + model_path = hf_hub_download( |
102 | 112 | repo_id=model.repo_id, |
103 | 113 | filename=model.filename, |
| 114 | + # NOTE(review): removed resume_download=True (deprecated no-op in huggingface_hub>=0.22; resume is default) |
| 115 | + # and tqdm_class=tqdm — hf_hub_download() has no such parameter (snapshot_download does) and would raise TypeError. |
| 116 | + ) |
| 117 | + loaded_model = Llama( |
| 118 | + model_path=model_path, |
104 | 119 | chat_handler=chat_handler, |
105 | 120 | n_ctx=generation_setting["context_len"], |
106 | 121 | ) |
107 | 122 | elif "MiniCPM" in model.model_id: |
| 123 | + print("Downloading vision model projector...") |
108 | 124 | chat_handler = MiniCPMv26ChatHandler.from_pretrained( |
109 | 125 | repo_id=model.repo_id, |
110 | 126 | filename="*mmproj*", |
111 | 127 | ) |
112 | | - loaded_model = Llama.from_pretrained( |
| 128 | + print("Downloading main model...") |
| 129 | + model_path = hf_hub_download( |
113 | 130 | repo_id=model.repo_id, |
114 | 131 | filename=model.filename, |
| 132 | + # NOTE(review): removed resume_download=True (deprecated no-op in huggingface_hub>=0.22; resume is default) |
| 133 | + # and tqdm_class=tqdm — hf_hub_download() has no such parameter (snapshot_download does) and would raise TypeError. |
| 134 | + ) |
| 135 | + loaded_model = Llama( |
| 136 | + model_path=model_path, |
115 | 137 | chat_handler=chat_handler, |
116 | 138 | n_ctx=generation_setting["context_len"], |
117 | 139 | ) |
118 | 140 | elif "llava-v1.5" in model.model_id: |
| 141 | + print("Downloading vision model projector...") |
119 | 142 | chat_handler = Llava15ChatHandler.from_pretrained( |
120 | 143 | repo_id=model.repo_id, |
121 | 144 | filename="*mmproj*", |
122 | 145 | ) |
123 | | - loaded_model = Llama.from_pretrained( |
| 146 | + print("Downloading main model...") |
| 147 | + model_path = hf_hub_download( |
124 | 148 | repo_id=model.repo_id, |
125 | 149 | filename=model.filename, |
| 150 | + # NOTE(review): removed resume_download=True (deprecated no-op in huggingface_hub>=0.22; resume is default) |
| 151 | + # and tqdm_class=tqdm — hf_hub_download() has no such parameter (snapshot_download does) and would raise TypeError. |
| 152 | + ) |
| 153 | + loaded_model = Llama( |
| 154 | + model_path=model_path, |
126 | 155 | chat_handler=chat_handler, |
127 | 156 | n_ctx=generation_setting["context_len"], |
128 | 157 | ) |
129 | 158 | elif "llava-v1.6" in model.model_id: |
| 159 | + print("Downloading vision model projector...") |
130 | 160 | chat_handler = Llava16ChatHandler.from_pretrained( |
131 | 161 | repo_id=model.repo_id, |
132 | 162 | filename="*mmproj*", |
133 | 163 | ) |
134 | | - loaded_model = Llama.from_pretrained( |
| 164 | + print("Downloading main model...") |
| 165 | + model_path = hf_hub_download( |
135 | 166 | repo_id=model.repo_id, |
136 | 167 | filename=model.filename, |
| 168 | + # NOTE(review): removed resume_download=True (deprecated no-op in huggingface_hub>=0.22; resume is default) |
| 169 | + # and tqdm_class=tqdm — hf_hub_download() has no such parameter (snapshot_download does) and would raise TypeError. |
| 170 | + ) |
| 171 | + loaded_model = Llama( |
| 172 | + model_path=model_path, |
137 | 173 | chat_handler=chat_handler, |
138 | 174 | n_ctx=generation_setting["context_len"], |
139 | 175 | ) |
|
0 commit comments