@@ -202,11 +202,16 @@ async def load_model(self, model_id: str) -> bool:
         log_model_unloaded(prev_model)
 
         hf_token = os.getenv("HF_TOKEN")
-        config = self._get_quantization_config()
-
+
         # Check quantization settings from environment variables
         enable_quantization = os.environ.get('LOCALLAB_ENABLE_QUANTIZATION', '').lower() not in ('false', '0', 'none', '')
         quantization_type = os.environ.get('LOCALLAB_QUANTIZATION_TYPE', '') if enable_quantization else "None"
+
+        # Get configuration based on quantization settings
+        config = self._get_quantization_config() if enable_quantization else {
+            "torch_dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
+            "device_map": "auto"  # Always use device_map="auto" for automatic placement
+        }
 
         if config and config.get("quantization_config"):
             logger.info(f"Using quantization config: {quantization_type}")
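The helper `_get_quantization_config()` is referenced but not shown in this diff. As a minimal sketch of the shape it plausibly returns, assuming a bitsandbytes-backed setup via transformers' `BitsAndBytesConfig` (the accepted `LOCALLAB_QUANTIZATION_TYPE` values here are assumptions, not taken from the diff):

```python
import os

import torch
from transformers import BitsAndBytesConfig


def _get_quantization_config(self) -> dict:
    """Hypothetical sketch: build from_pretrained() kwargs for the
    requested quantization type (assumes bitsandbytes is installed)."""
    quant_type = os.environ.get("LOCALLAB_QUANTIZATION_TYPE", "int8").lower()
    if quant_type == "int8":
        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
    else:
        # 4-bit variants ("fp4" or "nf4"); values assumed for illustration
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type=quant_type,
            bnb_4bit_compute_dtype=torch.float16,
        )
    # Shape matches how load_model() consumes it: the "quantization_config"
    # key is checked, then the whole dict is splatted into from_pretrained().
    return {"quantization_config": bnb_config, "device_map": "auto"}
```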
@@ -221,28 +226,23 @@ async def load_model(self, model_id: str) -> bool:
             token=hf_token
         )
 
+        # Load the model with device_map="auto" to let the library handle device placement
         self.model = AutoModelForCausalLM.from_pretrained(
             model_id,
             trust_remote_code=True,
             token=hf_token,
             **config
         )
-
-        # Check if the model has offloaded modules
-        if hasattr(self.model, 'is_offloaded') and self.model.is_offloaded:
-            logger.warning("Model has offloaded modules; skipping device move.")
-        else:
-            # Move model to the appropriate device only if quantization is disabled
-            if not enable_quantization:
-                device = "cuda" if torch.cuda.is_available() else "cpu"
-                self.model = self.model.to(device)
+
+        logger.info(f"Model loaded with device_map='auto' for automatic placement")
 
         # Capture model parameters after loading
         model_architecture = self.model.config.architectures[0] if hasattr(self.model.config, 'architectures') else 'Unknown'
         memory_used = torch.cuda.memory_allocated() if torch.cuda.is_available() else 'N/A'
         logger.info(f"Model architecture: {model_architecture}")
         logger.info(f"Memory used: {memory_used}")
 
+        # Apply optimizations if needed
         self.model = self._apply_optimizations(self.model)
 
         self.current_model = model_id
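The deleted branch tried to detect offloaded modules before calling `.to()`. With `device_map="auto"`, accelerate decides placement during loading and records it on the model, so a manual move is unnecessary and can fail when layers were offloaded to CPU or disk. A sketch of how that guard could be expressed against the attribute accelerate actually sets (`hf_device_map`); the surrounding names (`self.model`, `logger`) are assumed from the method's context:

```python
# After from_pretrained(..., device_map="auto"), accelerate records where
# each submodule landed; calling .to(device) on such a model is unnecessary
# and can raise if anything was offloaded to CPU or disk.
placement = getattr(self.model, "hf_device_map", None)
if placement is not None:
    logger.info(f"Device map chosen by accelerate: {placement}")
else:
    # No device map recorded: the model sits on a single device, .to() is safe.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    self.model = self.model.to(device)
```

With the new code path, both branches collapse into trusting `device_map="auto"`, which is why the diff simply logs the placement instead of moving the model.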