
Commit 7e1a0c4

Fixed LocalLab Model Download issue

1 parent 5d5cfea

8 files changed: +444 −90 lines

locallab/core/app.py

Lines changed: 41 additions & 1 deletion

@@ -84,6 +84,9 @@ def init(backend, **kwargs):
 # Startup event triggered flag
 startup_event_triggered = False

+# Model loading status flag
+model_loading_in_progress = False
+
 # Application startup event to ensure banners are displayed
 @app.on_event("startup")
 async def startup_event():

@@ -155,7 +158,11 @@ async def startup_event():
     if model_to_load:
         try:
             # This will run asynchronously without blocking server startup
+            # But we'll set a flag to indicate model loading is in progress
             asyncio.create_task(load_model_in_background(model_to_load))
+            # Set a global flag to indicate model is loading
+            global model_loading_in_progress
+            model_loading_in_progress = True
         except Exception as e:
             logger.error(f"Error starting model loading task: {str(e)}")
     else:

@@ -288,6 +295,7 @@ async def add_process_time_header(request: Request, call_next):

 async def load_model_in_background(model_id: str):
     """Load the model asynchronously in the background"""
+    global model_loading_in_progress
     logger.info(f"Loading model {model_id} in background...")
     start_time = time.time()

@@ -309,8 +317,40 @@ async def load_model_in_background(model_id: str):

         # We don't need to call log_model_loaded here since it's already done in the model_manager
         logger.info(f"{Fore.GREEN}Model {model_id} loaded successfully in {load_time:.2f} seconds!{Style.RESET_ALL}")
+
+        # Now that model is loaded, set server status to running
+        from ..logger.logger import set_server_status
+        set_server_status("running")
+        logger.info("Server status changed to: running")
+
+        # Mark model loading as complete
+        model_loading_in_progress = False
+
+        # Display the running banner now that model is loaded
+        try:
+            from ..ui.banners import print_running_banner
+            from .. import __version__
+            print_running_banner(__version__)
+        except Exception as banner_e:
+            logger.warning(f"Could not display running banner: {banner_e}")
+
     except Exception as e:
         logger.error(f"Failed to load model {model_id}: {str(e)}")
         if "401 Client Error: Unauthorized" in str(e):
             logger.error("This appears to be an authentication error. Please ensure your HuggingFace token is set correctly.")
-            logger.info("You can set your token using: locallab config")
+            logger.info("You can set your token using: locallab config")
+
+        # Even if model loading fails, mark it as complete and set server to running
+        # so the server can still be used for other operations
+        model_loading_in_progress = False
+        from ..logger.logger import set_server_status
+        set_server_status("running")
+        logger.info("Server status changed to: running (model loading failed)")
+
+        # Display the running banner even if model loading failed
+        try:
+            from ..ui.banners import print_running_banner
+            from .. import __version__
+            print_running_banner(__version__)
+        except Exception as banner_e:
+            logger.warning(f"Could not display running banner: {banner_e}")
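
For reference, the pattern this diff applies (start the load with asyncio.create_task so startup returns immediately, track progress with a module-level flag, and flip the status to "running" whether the load succeeds or fails) can be reduced to a minimal, runnable sketch. This is an illustration of the technique, not LocalLab's actual code; the sleep stands in for the real download:

import asyncio
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("sketch")

model_loading_in_progress = False
server_status = "loading"

async def load_model_in_background(model_id: str) -> None:
    """Stand-in for the real load: sleep instead of downloading weights."""
    global model_loading_in_progress, server_status
    try:
        await asyncio.sleep(0.5)  # placeholder for the actual download/load work
        logger.info("Model %s loaded", model_id)
    except Exception as e:
        logger.error("Failed to load model %s: %s", model_id, e)
    finally:
        # Success or failure, unblock the server so other routes keep working
        model_loading_in_progress = False
        server_status = "running"

async def main() -> None:
    global model_loading_in_progress
    model_loading_in_progress = True
    asyncio.create_task(load_model_in_background("demo/model"))  # non-blocking
    while model_loading_in_progress:  # the real server serves requests here
        await asyncio.sleep(0.1)
    logger.info("Server status: %s", server_status)

asyncio.run(main())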

locallab/model_manager.py

Lines changed: 316 additions & 53 deletions (large diff not rendered by default)

locallab/routes/models.py

Lines changed: 10 additions & 24 deletions

@@ -49,18 +49,11 @@ class LoadModelRequest(BaseModel):
 async def load_model(request: LoadModelRequest) -> Dict[str, str]:
     """Load a specific model"""
     try:
-        # Check if model exists in registry
-        if request.model_id not in MODEL_REGISTRY:
-            raise HTTPException(
-                status_code=404,
-                detail=f"Model {request.model_id} not found. Available models: {list(MODEL_REGISTRY.keys())}"
-            )
-
         # Check if model is already loaded
         if model_manager.current_model == request.model_id and model_manager.is_model_loaded(request.model_id):
             return {"status": "success", "message": f"Model {request.model_id} is already loaded"}
-
-        # Load the model
+
+        # Load the model (this will handle both registry and custom models)
         await model_manager.load_model(request.model_id)
         return {"status": "success", "message": f"Model {request.model_id} loaded successfully"}
     except Exception as e:

@@ -108,17 +101,14 @@ async def get_current_model() -> ModelResponse:
     )

 @router.post("/load/{model_id}", response_model=Dict[str, str])
-async def load_model(model_id: str, background_tasks: BackgroundTasks) -> Dict[str, str]:
+async def load_model_by_path(model_id: str, background_tasks: BackgroundTasks) -> Dict[str, str]:
     """Load a specific model"""
-    if model_id not in MODEL_REGISTRY:
-        raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
-
     # Check if the model is already loaded
     if model_manager.current_model == model_id and model_manager.is_model_loaded(model_id):
         return {"status": "success", "message": f"Model {model_id} is already loaded"}
-
+
     try:
-        # Load model in background
+        # Load model in background (this will handle both registry and custom models)
         background_tasks.add_task(model_manager.load_model, model_id)
         return {"status": "loading", "message": f"Model {model_id} loading started in background"}
     except Exception as e:

@@ -129,15 +119,13 @@ async def load_model(model_id: str, background_tasks: BackgroundTasks) -> Dict[str, str]:
 async def load_model_from_body(request: LoadModelRequest, background_tasks: BackgroundTasks) -> Dict[str, str]:
     """Load a specific model using model_id from request body"""
     model_id = request.model_id
-    if model_id not in MODEL_REGISTRY:
-        raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
-
+
     # Check if the model is already loaded
     if model_manager.current_model == model_id and model_manager.is_model_loaded(model_id):
         return {"status": "success", "message": f"Model {model_id} is already loaded"}
-
+
     try:
-        # Load model in background
+        # Load model in background (this will handle both registry and custom models)
         background_tasks.add_task(model_manager.load_model, model_id)
         return {"status": "loading", "message": f"Model {model_id} loading started in background"}
     except Exception as e:

@@ -161,12 +149,10 @@ async def unload_model() -> Dict[str, str]:
 @router.get("/status/{model_id}", response_model=ModelResponse)
 async def get_model_status(model_id: str) -> ModelResponse:
     """Get the loading status of a specific model"""
-    if model_id not in MODEL_REGISTRY:
-        raise HTTPException(status_code=404, detail=f"Model {model_id} not found")
-
+    # Check if model is in registry first, otherwise treat as custom model
     model_info = MODEL_REGISTRY.get(model_id, {})
     is_loaded = model_manager.is_model_loaded(model_id)
-
+
     # If this is the current model and it's loaded or loading
     if model_manager.current_model == model_id:
         return ModelResponse(
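
With the MODEL_REGISTRY checks removed, these endpoints should accept arbitrary Hugging Face model ids rather than only registry entries. A hypothetical client call; the host, port, and route prefix are assumptions about a local deployment, and the requests library is used purely for illustration:

import requests

# Assumed base URL and route prefix; adjust for your LocalLab instance.
resp = requests.post(
    "http://localhost:8000/models/load",
    json={"model_id": "microsoft/phi-2"},  # any HF id, registry or not
    timeout=30,
)
print(resp.json())  # expected shape: {"status": "loading", "message": "..."}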

locallab/server.py

Lines changed: 52 additions & 4 deletions

@@ -811,8 +811,34 @@ def on_startup():
     try:
         logger.info("Server startup callback triggered")

-        # Set server status to running
-        set_server_status("running")
+        # Check if a model is configured to load on startup
+        try:
+            from .cli.config import get_config_value
+            from .config import DEFAULT_MODEL
+            import os
+
+            # Get the model that should be loaded
+            model_to_load = (
+                os.environ.get("HUGGINGFACE_MODEL") or
+                get_config_value("model") or
+                DEFAULT_MODEL
+            )
+
+            if model_to_load:
+                # Set server status to loading while model loads
+                set_server_status("loading")
+                logger.info("Server status changed to: loading (waiting for model)")
+                # Don't display running banner yet - wait for model to load
+                return
+            else:
+                # No model to load, set to running immediately
+                set_server_status("running")
+                logger.info("Server status changed to: running")
+        except Exception as e:
+            # Fallback if anything fails
+            logger.warning(f"Could not determine model loading status: {e}")
+            set_server_status("running")
+            logger.info("Server status changed to: running")

         # Display the RUNNING banner
         print_running_banner(__version__)

@@ -862,8 +888,30 @@ def on_startup():
             logger.debug(f"Startup display error details: {traceback.format_exc()}")
         # Still mark startup as complete to avoid repeated attempts
         startup_complete[0] = True
-        # Ensure server status is set to running even if display fails
-        set_server_status("running")
+        # Check if a model is configured to load before setting to running
+        try:
+            from .cli.config import get_config_value
+            from .config import DEFAULT_MODEL
+            import os
+
+            # Get the model that should be loaded
+            model_to_load = (
+                os.environ.get("HUGGINGFACE_MODEL") or
+                get_config_value("model") or
+                DEFAULT_MODEL
+            )
+
+            if model_to_load:
+                set_server_status("loading")
+                logger.info("Server status changed to: loading (waiting for model)")
+            else:
+                set_server_status("running")
+                logger.info("Server status changed to: running")
+        except Exception as e:
+            # Fallback if anything fails
+            logger.warning(f"Could not determine model loading status: {e}")
+            set_server_status("running")
+            logger.info("Server status changed to: running")

     # Define async callback that uvicorn can call
     async def on_startup_async():
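
The model-resolution chain added in both hunks (environment variable first, then the saved CLI config, then the package default) works because Python's or operator short-circuits on the first truthy value. A standalone sketch with the config lookup stubbed out; DEFAULT_MODEL and the stub are placeholders, not LocalLab's real values:

import os

DEFAULT_MODEL = "some-org/default-model"  # placeholder, not LocalLab's actual default

def get_config_value(key: str):
    """Stub for the CLI config lookup; returns None when nothing is saved."""
    return None

model_to_load = (
    os.environ.get("HUGGINGFACE_MODEL")  # 1. explicit environment override
    or get_config_value("model")         # 2. value saved via the CLI config
    or DEFAULT_MODEL                     # 3. fall back to the package default
)
print(model_to_load)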

locallab/utils/early_config.py

Lines changed: 15 additions & 3 deletions

@@ -9,7 +9,14 @@
 import warnings

 # Configure environment variables for Hugging Face
-os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # Enable HF Transfer for better downloads
+# Only enable HF Transfer if the package is available
+try:
+    import hf_transfer
+    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # Enable HF Transfer for better downloads
+except ImportError:
+    # hf_transfer not available, disable it to avoid errors
+    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
+
 os.environ["TOKENIZERS_PARALLELISM"] = "true"  # Enable parallelism for tokenizers
 os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"  # Disable advisory warnings
 os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"  # Disable telemetry

@@ -106,9 +113,14 @@ def enable_hf_progress_bars():
         # Method 3: Set environment variable (works for all versions)
         os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "0"

-        # Also enable HF Transfer for better download experience
+        # Also enable HF Transfer for better download experience (only if available)
         if hasattr(huggingface_hub, "constants"):
-            huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
+            try:
+                import hf_transfer
+                huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
+            except ImportError:
+                # hf_transfer not available, don't enable it
+                huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = False
     except ImportError:
         pass
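
All three HF Transfer fixes in this commit follow the same optional-dependency probe: only turn the feature on when hf_transfer is importable, since huggingface_hub raises at download time if HF_HUB_ENABLE_HF_TRANSFER=1 is set but the package is missing. A reusable sketch of the probe; the helper name is made up:

import importlib.util
import os

def enable_if_installed(package: str, env_var: str) -> bool:
    """Set env_var to "1" only when `package` is importable, else "0"."""
    available = importlib.util.find_spec(package) is not None
    os.environ[env_var] = "1" if available else "0"
    return available

if enable_if_installed("hf_transfer", "HF_HUB_ENABLE_HF_TRANSFER"):
    print("hf_transfer enabled for faster downloads")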

locallab/utils/progress.py

Lines changed: 8 additions & 3 deletions

@@ -164,9 +164,14 @@ def configure_hf_hub_progress():
         from huggingface_hub.utils import logging as hf_logging
         hf_logging.enable_progress_bars()

-        # 2. Enable HF Transfer for better download experience
-        from huggingface_hub import constants
-        constants.HF_HUB_ENABLE_HF_TRANSFER = True
+        # 2. Enable HF Transfer for better download experience (only if available)
+        try:
+            import hf_transfer
+            from huggingface_hub import constants
+            constants.HF_HUB_ENABLE_HF_TRANSFER = True
+        except ImportError:
+            # hf_transfer not available, skip enabling it
+            pass

         # 3. Make sure we're NOT overriding HuggingFace's progress callback
         # This is critical - we want to use their native implementation

requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 fastapi>=0.68.0,<1.0.0
 uvicorn>=0.15.0,<1.0.0
 python-multipart>=0.0.5
-transformers>=4.0.0
+transformers>=4.49.0
 accelerate>=0.12.0
 pyngrok>=5.1.0
 nest-asyncio>=1.5.1

setup.py

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@
     "python-multipart>=0.0.5",
     "dataclasses-json>=0.5.7,<1.0.0",
     "torch>=2.0.0,<3.0.0",
-    "transformers>=4.28.1,<5.0.0",
+    "transformers>=4.49.0,<5.0.0",
     "accelerate>=0.18.0,<1.0.0",
     "click>=8.1.3,<9.0.0",
     "rich>=13.3.4,<14.0.0",
