@@ -132,16 +132,10 @@ def _estimate_model_size(self, model: LLMWrapper) -> str:
132132 if param_count_billions is not None :
133133 if param_count_billions >= 60 :
134134 return "very_large" # 60B+ parameters
135- elif param_count_billions >= 13 :
136- return "very_large" # 13-60B parameters
137- elif param_count_billions >= 7 :
138- return "medium" # 7-13B parameters
139- elif param_count_billions >= 3 :
140- return "small" # 3-7B parameters
141- elif param_count_billions >= 1 :
142- return "tiny" # 1-3B parameters
135+ elif param_count_billions >= 30 :
136+ return "large" # 30-60B parameters
143137 else :
144- return "micro " # <1B parameters
 138+ return "standard" # <30B parameters (no trailing space — must match the "standard" key in size_to_batch)
145139
146140 # Conservative default for unknown models
147141 self .logger .warning (f"Could not determine size for model { model .model_name } , using 'medium' batch size" )
@@ -159,22 +153,13 @@ def _get_adaptive_batch_size(self, model: LLMWrapper, num_probes: int) -> int:
159153
160154 # Base batch sizes by model size
161155 size_to_batch = {
162- "very_large" : 1 , # 60B+ models: process one at a time
163- "large" : 2 , # 13-60B models: small batches
164- "medium" : 4 , # 7-13B models: moderate batches
165- "small" : 8 , # 3-7B models: larger batches
166- "tiny" : 16 , # 1-3B models: large batches
167- "micro" : 32 # <1B models: very large batches
156+ "very_large" : 1 , # 60B+ models
157+ "large" : 2 , # 30-60B models
158+ "standard" : 8 , # <30B models
168159 }
169-
160+
170161 base_batch_size = size_to_batch [model_size ]
171162
172- # Further reduce batch size for very large probe sets
173- if num_probes > 100 :
174- base_batch_size = max (1 , base_batch_size // 2 )
175- elif num_probes > 50 :
176- base_batch_size = max (1 , base_batch_size * 3 // 4 )
177-
178163 self .adaptive_batch_size = base_batch_size
179164 self .logger .info (f"Using adaptive batch size { base_batch_size } for { model_size } model with { num_probes } probes" )
180165
@@ -323,10 +308,17 @@ def _extract_decoder_only_features(self, model, probe_inputs: List[str], max_len
323308 # Ensure attention mask is in the correct dtype to avoid BFloat16/Half issues
324309 inputs ['attention_mask' ] = inputs ['attention_mask' ].to (dtype = torch .long )
325310
311+ # Detect model dtype for autocast; fall back to no-cast on CPU
312+ model_dtype = next (model .model .parameters ()).dtype
313+ device_type = str (model .device ).split (":" )[0 ] # "cuda" or "cpu"
314+ use_autocast = device_type == "cuda" and model_dtype in (torch .bfloat16 , torch .float16 )
315+
326316 with torch .no_grad ():
327- # Perform a single forward pass to get hidden states (no generation)
328- # Request hidden states explicitly for models that don't return them by default
329- outputs = model .model (** inputs , output_hidden_states = True )
317+ if use_autocast :
318+ with torch .autocast (device_type = device_type , dtype = model_dtype ):
319+ outputs = model .model (** inputs , output_hidden_states = True )
320+ else :
321+ outputs = model .model (** inputs , output_hidden_states = True )
330322
331323 # Get the hidden states from the last layer - handle different output formats
332324 if hasattr (outputs , 'last_hidden_state' ):
@@ -336,8 +328,8 @@ def _extract_decoder_only_features(self, model, probe_inputs: List[str], max_len
336328 last_hidden_state = outputs .hidden_states [- 1 ]
337329 else :
338330 raise ValueError (f"Model output does not contain accessible hidden states. Available attributes: { list (outputs .__dict__ .keys ())} " )
339-
340- # Convert dtype early to avoid precision issues
331+
332+ # Convert to float32 for stable downstream computation
341333 last_hidden_state = last_hidden_state .to (dtype = torch .float32 )
342334
343335 # Find the index of the last non-padding token for each sequence
@@ -424,15 +416,22 @@ def _extract_encoder_decoder_features(self, model, probe_inputs: List[str], max_
424416 # Ensure attention mask is in the correct dtype to avoid BFloat16/Half issues
425417 inputs ['attention_mask' ] = inputs ['attention_mask' ].to (dtype = torch .long )
426418
419+ model_dtype = next (model .model .parameters ()).dtype
420+ device_type = str (model .device ).split (":" )[0 ]
421+ use_autocast = device_type == "cuda" and model_dtype in (torch .bfloat16 , torch .float16 )
422+
427423 with torch .no_grad ():
428- # Get model outputs with encoder hidden states
429- outputs = model .model (** inputs , output_hidden_states = True )
430-
424+ if use_autocast :
425+ with torch .autocast (device_type = device_type , dtype = model_dtype ):
426+ outputs = model .model (** inputs , output_hidden_states = True )
427+ else :
428+ outputs = model .model (** inputs , output_hidden_states = True )
429+
431430 # Check if encoder_last_hidden_state exists
432431 if not hasattr (outputs , 'encoder_last_hidden_state' ):
433432 self .logger .warning (f"Model does not provide encoder_last_hidden_state for batch { i } " )
434433 continue
435-
434+
436435 # Get the encoder's final hidden states and convert dtype early
437436 encoder_hidden_states = outputs .encoder_last_hidden_state .to (dtype = torch .float32 )
438437 attention_mask = inputs ['attention_mask' ]
0 commit comments