
Commit 8831d94

Updated package to v0.4.2
1 parent 6234227 commit 8831d94

3 files changed: +62 -39 lines changed


locallab/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 LocalLab: Run LLMs locally with a friendly API similar to OpenAI
 """
 
-__version__ = "0.4.1"
+__version__ = "0.4.2"
 
 from typing import Dict, Any, Optional

locallab/model_manager.py

Lines changed: 60 additions & 37 deletions
@@ -63,7 +63,10 @@ def __init__(self):
     def _get_quantization_config(self) -> Optional[Dict[str, Any]]:
         """Get quantization configuration based on settings"""
         # Check if quantization is explicitly disabled (not just False but also '0', 'none', '')
-        if not ENABLE_QUANTIZATION or str(ENABLE_QUANTIZATION).lower() in ('false', '0', 'none', ''):
+        enable_quantization = os.environ.get('LOCALLAB_ENABLE_QUANTIZATION', '').lower() not in ('false', '0', 'none', '')
+        quantization_type = os.environ.get('LOCALLAB_QUANTIZATION_TYPE', '')
+
+        if not enable_quantization:
             logger.info("Quantization is disabled, using default precision")
             return {
                 "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
@@ -85,15 +88,15 @@ def _get_quantization_config(self) -> Optional[Dict[str, Any]]:
             }
 
         # Check for empty quantization type
-        if not QUANTIZATION_TYPE or QUANTIZATION_TYPE.lower() in ('none', ''):
+        if not quantization_type or quantization_type.lower() in ('none', ''):
             logger.info(
                 "No quantization type specified, defaulting to fp16")
             return {
                 "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
                 "device_map": "auto"
             }
 
-        if QUANTIZATION_TYPE == "int8":
+        if quantization_type == "int8":
             logger.info("Using INT8 quantization")
             return {
                 "device_map": "auto",
@@ -104,7 +107,7 @@ def _get_quantization_config(self) -> Optional[Dict[str, Any]]:
                     bnb_8bit_use_double_quant=True
                 )
             }
-        elif QUANTIZATION_TYPE == "int4":
+        elif quantization_type == "int4":
             logger.info("Using INT4 quantization")
             return {
                 "device_map": "auto",
@@ -116,7 +119,7 @@ def _get_quantization_config(self) -> Optional[Dict[str, Any]]:
                 )
             }
         else:
-            logger.info(f"Unrecognized quantization type '{QUANTIZATION_TYPE}', defaulting to fp16")
+            logger.info(f"Unrecognized quantization type '{quantization_type}', defaulting to fp16")
             return {
                 "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
                 "device_map": "auto"
@@ -131,17 +134,6 @@ def _get_quantization_config(self) -> Optional[Dict[str, Any]]:
             "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
             "device_map": "auto"
         }
-        except Exception as e:
-            logger.warning(f"Error configuring quantization: {str(e)}. Falling back to fp16.")
-            return {
-                "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
-                "device_map": "auto"
-            }
-
-        return {
-            "torch_dtype": torch.float16,
-            "device_map": "auto"
-        }
 
     def _apply_optimizations(self, model: AutoModelForCausalLM) -> AutoModelForCausalLM:
         """Apply various optimizations to the model"""
@@ -212,8 +204,12 @@ async def load_model(self, model_id: str) -> bool:
         hf_token = os.getenv("HF_TOKEN")
         config = self._get_quantization_config()
 
+        # Check quantization settings from environment variables
+        enable_quantization = os.environ.get('LOCALLAB_ENABLE_QUANTIZATION', '').lower() not in ('false', '0', 'none', '')
+        quantization_type = os.environ.get('LOCALLAB_QUANTIZATION_TYPE', '') if enable_quantization else "None"
+
         if config and config.get("quantization_config"):
-            logger.info(f"Using quantization config: {QUANTIZATION_TYPE}")
+            logger.info(f"Using quantization config: {quantization_type}")
         else:
             precision = "fp16" if torch.cuda.is_available() else "fp32"
             logger.info(f"Using {precision} precision (no quantization)")
@@ -232,18 +228,14 @@ async def load_model(self, model_id: str) -> bool:
             **config
         )
 
-        # Move model only if quantization is disabled
-        if not ENABLE_QUANTIZATION or str(ENABLE_QUANTIZATION).lower() in ('false', '0', 'none', ''):
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-            logger.info(f"Moving model to {device}")
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_id,
-                trust_remote_code=True,
-                token=hf_token,
-                device_map="auto"
-            )
+        # Check if the model has offloaded modules
+        if hasattr(self.model, 'is_offloaded') and self.model.is_offloaded:
+            logger.warning("Model has offloaded modules; skipping device move.")
         else:
-            logger.info("Skipping device move for quantized model - using device_map='auto'")
+            # Move model to the appropriate device only if quantization is disabled
+            if not enable_quantization:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+                self.model = self.model.to(device)
 
         # Capture model parameters after loading
         model_architecture = self.model.config.architectures[0] if hasattr(self.model.config, 'architectures') else 'Unknown'
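
This hunk also fixes a real bug: the old branch called from_pretrained a second time just to change device placement, re-loading the whole model. The new logic mutates the already-loaded instance. Restated as a standalone helper (safe_move is illustrative; is_offloaded is read defensively because not every model class defines it, and calling .to() on a quantized or offloaded model can raise in recent transformers/accelerate versions):

    import torch

    def safe_move(model, enable_quantization: bool):
        # accelerate manages placement for offloaded modules; leave them alone.
        if getattr(model, "is_offloaded", False):
            return model
        # Only unquantized models are moved; quantized ones rely on device_map="auto".
        if not enable_quantization:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model = model.to(device)
        return model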
@@ -374,8 +366,13 @@ async def generate(
             logger.warning(f"Invalid repetition_penalty value: {repetition_penalty}. Using model default.")
 
         inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
+
+        # Get the actual device of the model
+        model_device = next(self.model.parameters()).device
+
+        # Move inputs to the same device as the model
         for key in inputs:
-            inputs[key] = inputs[key].to(self.device)
+            inputs[key] = inputs[key].to(model_device)
 
         if stream:
             return self.async_stream_generate(inputs, gen_params)
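
Querying next(self.model.parameters()).device returns where the weights actually landed, which matters once device_map="auto" may shard or offload the model: a cached self.device can go stale. A self-contained sketch of the pattern (the model name is illustrative):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")            # small model for illustration
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tok("Hello", return_tensors="pt")

    model_device = next(model.parameters()).device         # actual placement
    inputs = {k: v.to(model_device) for k, v in inputs.items()}
    out = model.generate(**inputs, max_new_tokens=8)
    print(tok.decode(out[0], skip_special_tokens=True))

One caveat: for a model sharded across several devices, the first parameter's device is only a heuristic; inputs need to go where the embedding layer lives, which for transformers models is usually the first shard anyway.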
@@ -436,6 +433,14 @@ def _stream_generate(
             top_k = 50
             repetition_penalty = 1.1
 
+            # Get the actual device of the model
+            model_device = next(self.model.parameters()).device
+
+            # Ensure inputs are on the same device as the model
+            for key in inputs:
+                if inputs[key].device != model_device:
+                    inputs[key] = inputs[key].to(model_device)
+
             with torch.no_grad():
                 for _ in range(max_length):
                     generate_params = {
@@ -463,6 +468,11 @@ def _stream_generate(
                     yield new_token
                     inputs = {"input_ids": outputs,
                               "attention_mask": torch.ones_like(outputs)}
+
+                    # Ensure the updated inputs are on the correct device
+                    for key in inputs:
+                        if inputs[key].device != model_device:
+                            inputs[key] = inputs[key].to(model_device)
 
         except Exception as e:
             logger.error(f"Streaming generation failed: {str(e)}")
@@ -500,8 +510,13 @@ async def async_stream_generate(self, inputs: Dict[str, torch.Tensor] = None, ge
 
         # Tokenize the prompt
         inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
+
+        # Get the actual device of the model
+        model_device = next(self.model.parameters()).device
+
+        # Move inputs to the same device as the model
         for key in inputs:
-            inputs[key] = inputs[key].to(self.device)
+            inputs[key] = inputs[key].to(model_device)
 
         # Now stream tokens using the prepared inputs and parameters
         for token in self._stream_generate(inputs, gen_params=gen_params):
@@ -528,6 +543,14 @@ def get_model_info(self) -> Dict[str, Any]:
         vram_required = self.model_config.get("vram", "Unknown") if isinstance(
             self.model_config, dict) else "Unknown"
 
+        # Check environment variables for optimization settings
+        enable_quantization = os.environ.get('LOCALLAB_ENABLE_QUANTIZATION', '').lower() not in ('false', '0', 'none', '')
+        quantization_type = os.environ.get('LOCALLAB_QUANTIZATION_TYPE', '') if enable_quantization else "None"
+
+        # If quantization is disabled or type is empty, set to "None"
+        if not enable_quantization or not quantization_type or quantization_type.lower() in ('none', ''):
+            quantization_type = "None"
+
         model_info = {
             "model_id": self.current_model,
             "model_name": model_name,
model_info = {
532555
"model_id": self.current_model,
533556
"model_name": model_name,
@@ -538,11 +561,11 @@ def get_model_info(self) -> Dict[str, Any]:
538561
"ram_required": ram_required,
539562
"vram_required": vram_required,
540563
"memory_used": f"{memory_used / (1024 * 1024):.2f} MB",
541-
"quantization": QUANTIZATION_TYPE if ENABLE_QUANTIZATION else "None",
564+
"quantization": quantization_type,
542565
"optimizations": {
543-
"attention_slicing": ENABLE_ATTENTION_SLICING,
544-
"flash_attention": ENABLE_FLASH_ATTENTION,
545-
"better_transformer": ENABLE_BETTERTRANSFORMER
566+
"attention_slicing": os.environ.get('LOCALLAB_ENABLE_ATTENTION_SLICING', '').lower() not in ('false', '0', 'none', ''),
567+
"flash_attention": os.environ.get('LOCALLAB_ENABLE_FLASH_ATTENTION', '').lower() not in ('false', '0', 'none', ''),
568+
"better_transformer": os.environ.get('LOCALLAB_ENABLE_BETTERTRANSFORMER', '').lower() not in ('false', '0', 'none', '')
546569
}
547570
}
548571

@@ -560,9 +583,9 @@ def get_model_info(self) -> Dict[str, Any]:
 • Quantization: {Fore.YELLOW}{model_info['quantization']}{Style.RESET_ALL}
 
 {Fore.GREEN}🔧 Optimizations{Style.RESET_ALL}
-• Attention Slicing: {Fore.YELLOW}{str(ENABLE_ATTENTION_SLICING)}{Style.RESET_ALL}
-• Flash Attention: {Fore.YELLOW}{str(ENABLE_FLASH_ATTENTION)}{Style.RESET_ALL}
-• Better Transformer: {Fore.YELLOW}{str(ENABLE_BETTERTRANSFORMER)}{Style.RESET_ALL}
+• Attention Slicing: {Fore.YELLOW}{str(model_info['optimizations']['attention_slicing'])}{Style.RESET_ALL}
+• Flash Attention: {Fore.YELLOW}{str(model_info['optimizations']['flash_attention'])}{Style.RESET_ALL}
+• Better Transformer: {Fore.YELLOW}{str(model_info['optimizations']['better_transformer'])}{Style.RESET_ALL}
 """)
 
         return model_info
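
Taken together, the commit moves every optimization toggle from import-time constants to environment variables. A hypothetical end-to-end configuration, using only the variable names this diff actually reads:

    import os

    # Set before loading a model; unset or 'false'/'0'/'none'/'' means disabled.
    os.environ["LOCALLAB_ENABLE_QUANTIZATION"] = "true"
    os.environ["LOCALLAB_QUANTIZATION_TYPE"] = "int8"  # or "int4"; anything else falls back to fp16
    os.environ["LOCALLAB_ENABLE_ATTENTION_SLICING"] = "true"
    os.environ["LOCALLAB_ENABLE_FLASH_ATTENTION"] = "false"
    os.environ["LOCALLAB_ENABLE_BETTERTRANSFORMER"] = "true"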

setup.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 
 setup(
     name="locallab",
-    version="0.4.1",
+    version="0.4.2",
     packages=find_packages(include=["locallab", "locallab.*"]),
     install_requires=[
         "fastapi>=0.95.0,<1.0.0",
