@@ -63,7 +63,10 @@ def __init__(self):
     def _get_quantization_config(self) -> Optional[Dict[str, Any]]:
         """Get quantization configuration based on settings"""
         # Check if quantization is explicitly disabled (not just False but also '0', 'none', '')
-        if not ENABLE_QUANTIZATION or str(ENABLE_QUANTIZATION).lower() in ('false', '0', 'none', ''):
+        enable_quantization = os.environ.get('LOCALLAB_ENABLE_QUANTIZATION', '').lower() not in ('false', '0', 'none', '')
+        quantization_type = os.environ.get('LOCALLAB_QUANTIZATION_TYPE', '')
+
+        if not enable_quantization:
             logger.info("Quantization is disabled, using default precision")
             return {
                 "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
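The `not in ('false', '0', 'none', '')` test introduced here is repeated for every `LOCALLAB_*` flag in this commit. A minimal sketch of how the check could be centralised; the helper name `_env_flag` is hypothetical and not part of this change:

```python
import os

def _env_flag(name: str, default: str = "") -> bool:
    """Hypothetical helper: unset, '', 'false', '0' and 'none' all mean disabled."""
    return os.environ.get(name, default).lower() not in ("false", "0", "none", "")

# Mirrors the values read in this commit
enable_quantization = _env_flag("LOCALLAB_ENABLE_QUANTIZATION")
quantization_type = os.environ.get("LOCALLAB_QUANTIZATION_TYPE", "")
```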
@@ -85,15 +88,15 @@ def _get_quantization_config(self) -> Optional[Dict[str, Any]]:
             }
 
         # Check for empty quantization type
-        if not QUANTIZATION_TYPE or QUANTIZATION_TYPE.lower() in ('none', ''):
+        if not quantization_type or quantization_type.lower() in ('none', ''):
             logger.info(
                 "No quantization type specified, defaulting to fp16")
             return {
                 "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
                 "device_map": "auto"
             }
 
-        if QUANTIZATION_TYPE == "int8":
+        if quantization_type == "int8":
             logger.info("Using INT8 quantization")
             return {
                 "device_map": "auto",
@@ -104,7 +107,7 @@ def _get_quantization_config(self) -> Optional[Dict[str, Any]]:
                     bnb_8bit_use_double_quant=True
                 )
             }
-        elif QUANTIZATION_TYPE == "int4":
+        elif quantization_type == "int4":
             logger.info("Using INT4 quantization")
             return {
                 "device_map": "auto",
@@ -116,7 +119,7 @@ def _get_quantization_config(self) -> Optional[Dict[str, Any]]:
                 )
             }
         else:
-            logger.info(f"Unrecognized quantization type '{QUANTIZATION_TYPE}', defaulting to fp16")
+            logger.info(f"Unrecognized quantization type '{quantization_type}', defaulting to fp16")
             return {
                 "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
                 "device_map": "auto"
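For reference, this is how a 4-bit `BitsAndBytesConfig` of the kind built above is typically constructed and handed to `from_pretrained` in the transformers/bitsandbytes integration. A sketch only: the model id is a placeholder and the parameter values are illustrative rather than taken from this commit, and it assumes bitsandbytes and a CUDA device are available:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                       # INT4 weights
    bnb_4bit_compute_dtype=torch.float16,    # compute in fp16
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",                                  # placeholder model id
    device_map="auto",
    quantization_config=bnb_config,
)
```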
@@ -131,17 +134,6 @@ def _get_quantization_config(self) -> Optional[Dict[str, Any]]:
                 "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
                 "device_map": "auto"
             }
-        except Exception as e:
-            logger.warning(f"Error configuring quantization: {str(e)}. Falling back to fp16.")
-            return {
-                "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
-                "device_map": "auto"
-            }
-
-        return {
-            "torch_dtype": torch.float16,
-            "device_map": "auto"
-        }
 
     def _apply_optimizations(self, model: AutoModelForCausalLM) -> AutoModelForCausalLM:
         """Apply various optimizations to the model"""
@@ -212,8 +204,12 @@ async def load_model(self, model_id: str) -> bool:
         hf_token = os.getenv("HF_TOKEN")
         config = self._get_quantization_config()
 
+        # Check quantization settings from environment variables
+        enable_quantization = os.environ.get('LOCALLAB_ENABLE_QUANTIZATION', '').lower() not in ('false', '0', 'none', '')
+        quantization_type = os.environ.get('LOCALLAB_QUANTIZATION_TYPE', '') if enable_quantization else "None"
+
         if config and config.get("quantization_config"):
-            logger.info(f"Using quantization config: {QUANTIZATION_TYPE}")
+            logger.info(f"Using quantization config: {quantization_type}")
         else:
             precision = "fp16" if torch.cuda.is_available() else "fp32"
             logger.info(f"Using {precision} precision (no quantization)")
@@ -232,18 +228,14 @@ async def load_model(self, model_id: str) -> bool:
             **config
         )
 
-        # Move model only if quantization is disabled
-        if not ENABLE_QUANTIZATION or str(ENABLE_QUANTIZATION).lower() in ('false', '0', 'none', ''):
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-            logger.info(f"Moving model to {device}")
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_id,
-                trust_remote_code=True,
-                token=hf_token,
-                device_map="auto"
-            )
+        # Check if the model has offloaded modules
+        if hasattr(self.model, 'is_offloaded') and self.model.is_offloaded:
+            logger.warning("Model has offloaded modules; skipping device move.")
         else:
-            logger.info("Skipping device move for quantized model - using device_map='auto'")
+            # Move model to the appropriate device only if quantization is disabled
+            if not enable_quantization:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+                self.model = self.model.to(device)
 
         # Capture model parameters after loading
         model_architecture = self.model.config.architectures[0] if hasattr(self.model.config, 'architectures') else 'Unknown'
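The `is_offloaded` attribute checked above appears to be project-specific rather than a standard transformers attribute. A sketch of an alternative check based on `hf_device_map`, which transformers/accelerate attach to models dispatched with `device_map="auto"`; shown as an assumption, and the helper name is hypothetical:

```python
def _has_offloaded_modules(model) -> bool:
    """Hypothetical helper: treat any module mapped to CPU or disk as offloaded."""
    device_map = getattr(model, "hf_device_map", None) or {}
    return any(str(device) in ("cpu", "disk") for device in device_map.values())
```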
@@ -374,8 +366,13 @@ async def generate(
             logger.warning(f"Invalid repetition_penalty value: {repetition_penalty}. Using model default.")
 
         inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
+
+        # Get the actual device of the model
+        model_device = next(self.model.parameters()).device
+
+        # Move inputs to the same device as the model
         for key in inputs:
-            inputs[key] = inputs[key].to(self.device)
+            inputs[key] = inputs[key].to(model_device)
 
         if stream:
             return self.async_stream_generate(inputs, gen_params)
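A side note on the per-key loop above: the object returned by the tokenizer is a `BatchEncoding`, which can move all of its tensors in a single `.to(device)` call. A standalone sketch under that assumption, using a placeholder model id:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # placeholder model id
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Move the whole encoding to wherever the model's parameters live
model_device = next(model.parameters()).device
inputs = tokenizer("Hello, world", return_tensors="pt").to(model_device)
outputs = model.generate(**inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```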
@@ -436,6 +433,14 @@ def _stream_generate(
                 top_k = 50
                 repetition_penalty = 1.1
 
+            # Get the actual device of the model
+            model_device = next(self.model.parameters()).device
+
+            # Ensure inputs are on the same device as the model
+            for key in inputs:
+                if inputs[key].device != model_device:
+                    inputs[key] = inputs[key].to(model_device)
+
             with torch.no_grad():
                 for _ in range(max_length):
                     generate_params = {
@@ -463,6 +468,11 @@ def _stream_generate(
                         yield new_token
                     inputs = {"input_ids": outputs,
                               "attention_mask": torch.ones_like(outputs)}
+
+                    # Ensure the updated inputs are on the correct device
+                    for key in inputs:
+                        if inputs[key].device != model_device:
+                            inputs[key] = inputs[key].to(model_device)
 
         except Exception as e:
             logger.error(f"Streaming generation failed: {str(e)}")
@@ -500,8 +510,13 @@ async def async_stream_generate(self, inputs: Dict[str, torch.Tensor] = None, ge
 
         # Tokenize the prompt
         inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
+
+        # Get the actual device of the model
+        model_device = next(self.model.parameters()).device
+
+        # Move inputs to the same device as the model
         for key in inputs:
-            inputs[key] = inputs[key].to(self.device)
+            inputs[key] = inputs[key].to(model_device)
 
         # Now stream tokens using the prepared inputs and parameters
         for token in self._stream_generate(inputs, gen_params=gen_params):
@@ -528,6 +543,14 @@ def get_model_info(self) -> Dict[str, Any]:
         vram_required = self.model_config.get("vram", "Unknown") if isinstance(
             self.model_config, dict) else "Unknown"
 
+        # Check environment variables for optimization settings
+        enable_quantization = os.environ.get('LOCALLAB_ENABLE_QUANTIZATION', '').lower() not in ('false', '0', 'none', '')
+        quantization_type = os.environ.get('LOCALLAB_QUANTIZATION_TYPE', '') if enable_quantization else "None"
+
+        # If quantization is disabled or type is empty, set to "None"
+        if not enable_quantization or not quantization_type or quantization_type.lower() in ('none', ''):
+            quantization_type = "None"
+
         model_info = {
             "model_id": self.current_model,
             "model_name": model_name,
@@ -538,11 +561,11 @@ def get_model_info(self) -> Dict[str, Any]:
             "ram_required": ram_required,
             "vram_required": vram_required,
             "memory_used": f"{memory_used / (1024 * 1024):.2f} MB",
-            "quantization": QUANTIZATION_TYPE if ENABLE_QUANTIZATION else "None",
+            "quantization": quantization_type,
             "optimizations": {
-                "attention_slicing": ENABLE_ATTENTION_SLICING,
-                "flash_attention": ENABLE_FLASH_ATTENTION,
-                "better_transformer": ENABLE_BETTERTRANSFORMER
+                "attention_slicing": os.environ.get('LOCALLAB_ENABLE_ATTENTION_SLICING', '').lower() not in ('false', '0', 'none', ''),
+                "flash_attention": os.environ.get('LOCALLAB_ENABLE_FLASH_ATTENTION', '').lower() not in ('false', '0', 'none', ''),
+                "better_transformer": os.environ.get('LOCALLAB_ENABLE_BETTERTRANSFORMER', '').lower() not in ('false', '0', 'none', '')
             }
         }
 
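Since all of these settings are read from `LOCALLAB_*` environment variables at runtime, they can be set before the server process starts. A short usage sketch with illustrative values; any of 'false', '0', 'none' or an unset/empty value disables a feature:

```python
import os

os.environ["LOCALLAB_ENABLE_QUANTIZATION"] = "true"
os.environ["LOCALLAB_QUANTIZATION_TYPE"] = "int8"        # or "int4"
os.environ["LOCALLAB_ENABLE_ATTENTION_SLICING"] = "true"
os.environ["LOCALLAB_ENABLE_FLASH_ATTENTION"] = "false"
os.environ["LOCALLAB_ENABLE_BETTERTRANSFORMER"] = "true"
```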
@@ -560,9 +583,9 @@ def get_model_info(self) -> Dict[str, Any]:
 • Quantization: {Fore.YELLOW}{model_info['quantization']}{Style.RESET_ALL}
 
 {Fore.GREEN}🔧 Optimizations{Style.RESET_ALL}
-• Attention Slicing: {Fore.YELLOW}{str(ENABLE_ATTENTION_SLICING)}{Style.RESET_ALL}
-• Flash Attention: {Fore.YELLOW}{str(ENABLE_FLASH_ATTENTION)}{Style.RESET_ALL}
-• Better Transformer: {Fore.YELLOW}{str(ENABLE_BETTERTRANSFORMER)}{Style.RESET_ALL}
+• Attention Slicing: {Fore.YELLOW}{str(model_info['optimizations']['attention_slicing'])}{Style.RESET_ALL}
+• Flash Attention: {Fore.YELLOW}{str(model_info['optimizations']['flash_attention'])}{Style.RESET_ALL}
+• Better Transformer: {Fore.YELLOW}{str(model_info['optimizations']['better_transformer'])}{Style.RESET_ALL}
 """)
 
         return model_info