Skip to content

Commit 21f89ff

Browse files
JerryLife and claude
authored and committed
fix: improve inference stability for large and quantized models
Add torch.autocast for bf16/fp16 models on CUDA to prevent dtype mismatches during forward passes. Skip redundant .to(device) when accelerate has already dispatched the model via device_map. Simplify model size tiers and batch sizing. Default skip_chat_template to True. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 0ea6779 commit 21f89ff

File tree

3 files changed

+43
-38
lines changed

3 files changed

+43
-38
lines changed

src/llm_dna/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class DNAExtractionConfig:
4949
gpu_id: Optional[int] = None
5050
log_level: str = "INFO"
5151
random_seed: int = 42
52-
skip_chat_template: bool = False
52+
skip_chat_template: bool = True
5353

5454

5555
@dataclass(slots=True)

src/llm_dna/dna/EmbeddingDNAExtractor.py

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -132,16 +132,10 @@ def _estimate_model_size(self, model: LLMWrapper) -> str:
132132
if param_count_billions is not None:
133133
if param_count_billions >= 60:
134134
return "very_large" # 60B+ parameters
135-
elif param_count_billions >= 13:
136-
return "very_large" # 13-60B parameters
137-
elif param_count_billions >= 7:
138-
return "medium" # 7-13B parameters
139-
elif param_count_billions >= 3:
140-
return "small" # 3-7B parameters
141-
elif param_count_billions >= 1:
142-
return "tiny" # 1-3B parameters
135+
elif param_count_billions >= 30:
136+
return "large" # 30-60B parameters
143137
else:
144-
return "micro" # <1B parameters
138+
return "standard" # <30B parameters
145139

146140
# Conservative default for unknown models
147141
self.logger.warning(f"Could not determine size for model {model.model_name}, using 'medium' batch size")
@@ -159,22 +153,13 @@ def _get_adaptive_batch_size(self, model: LLMWrapper, num_probes: int) -> int:
159153

160154
# Base batch sizes by model size
161155
size_to_batch = {
162-
"very_large": 1, # 60B+ models: process one at a time
163-
"large": 2, # 13-60B models: small batches
164-
"medium": 4, # 7-13B models: moderate batches
165-
"small": 8, # 3-7B models: larger batches
166-
"tiny": 16, # 1-3B models: large batches
167-
"micro": 32 # <1B models: very large batches
156+
"very_large": 1, # 60B+ models
157+
"large": 2, # 30-60B models
158+
"standard": 8, # <30B models
168159
}
169-
160+
170161
base_batch_size = size_to_batch[model_size]
171162

172-
# Further reduce batch size for very large probe sets
173-
if num_probes > 100:
174-
base_batch_size = max(1, base_batch_size // 2)
175-
elif num_probes > 50:
176-
base_batch_size = max(1, base_batch_size * 3 // 4)
177-
178163
self.adaptive_batch_size = base_batch_size
179164
self.logger.info(f"Using adaptive batch size {base_batch_size} for {model_size} model with {num_probes} probes")
180165

@@ -323,10 +308,17 @@ def _extract_decoder_only_features(self, model, probe_inputs: List[str], max_len
323308
# Ensure attention mask is in the correct dtype to avoid BFloat16/Half issues
324309
inputs['attention_mask'] = inputs['attention_mask'].to(dtype=torch.long)
325310

311+
# Detect model dtype for autocast; fall back to no-cast on CPU
312+
model_dtype = next(model.model.parameters()).dtype
313+
device_type = str(model.device).split(":")[0] # "cuda" or "cpu"
314+
use_autocast = device_type == "cuda" and model_dtype in (torch.bfloat16, torch.float16)
315+
326316
with torch.no_grad():
327-
# Perform a single forward pass to get hidden states (no generation)
328-
# Request hidden states explicitly for models that don't return them by default
329-
outputs = model.model(**inputs, output_hidden_states=True)
317+
if use_autocast:
318+
with torch.autocast(device_type=device_type, dtype=model_dtype):
319+
outputs = model.model(**inputs, output_hidden_states=True)
320+
else:
321+
outputs = model.model(**inputs, output_hidden_states=True)
330322

331323
# Get the hidden states from the last layer - handle different output formats
332324
if hasattr(outputs, 'last_hidden_state'):
@@ -336,8 +328,8 @@ def _extract_decoder_only_features(self, model, probe_inputs: List[str], max_len
336328
last_hidden_state = outputs.hidden_states[-1]
337329
else:
338330
raise ValueError(f"Model output does not contain accessible hidden states. Available attributes: {list(outputs.__dict__.keys())}")
339-
340-
# Convert dtype early to avoid precision issues
331+
332+
# Convert to float32 for stable downstream computation
341333
last_hidden_state = last_hidden_state.to(dtype=torch.float32)
342334

343335
# Find the index of the last non-padding token for each sequence
@@ -424,15 +416,22 @@ def _extract_encoder_decoder_features(self, model, probe_inputs: List[str], max_
424416
# Ensure attention mask is in the correct dtype to avoid BFloat16/Half issues
425417
inputs['attention_mask'] = inputs['attention_mask'].to(dtype=torch.long)
426418

419+
model_dtype = next(model.model.parameters()).dtype
420+
device_type = str(model.device).split(":")[0]
421+
use_autocast = device_type == "cuda" and model_dtype in (torch.bfloat16, torch.float16)
422+
427423
with torch.no_grad():
428-
# Get model outputs with encoder hidden states
429-
outputs = model.model(**inputs, output_hidden_states=True)
430-
424+
if use_autocast:
425+
with torch.autocast(device_type=device_type, dtype=model_dtype):
426+
outputs = model.model(**inputs, output_hidden_states=True)
427+
else:
428+
outputs = model.model(**inputs, output_hidden_states=True)
429+
431430
# Check if encoder_last_hidden_state exists
432431
if not hasattr(outputs, 'encoder_last_hidden_state'):
433432
self.logger.warning(f"Model does not provide encoder_last_hidden_state for batch {i}")
434433
continue
435-
434+
436435
# Get the encoder's final hidden states and convert dtype early
437436
encoder_hidden_states = outputs.encoder_last_hidden_state.to(dtype=torch.float32)
438437
attention_mask = inputs['attention_mask']

src/llm_dna/models/ModelWrapper.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -504,12 +504,18 @@ def _load_model_and_tokenizer(self):
504504

505505
# Ensure all model components are on the same device
506506
if quantization_config is None:
507-
self.logger.info(f"Moving non-quantized model to device: {self.device}")
508-
self.model = self.model.to(self.device)
509-
# Ensure all parameters are on the same device
510-
for param in self.model.parameters():
511-
if param.device != torch.device(self.device):
512-
param.data = param.data.to(self.device)
507+
# If accelerate already dispatched via device_map, skip .to()
508+
if hasattr(self.model, "hf_device_map"):
509+
self.logger.info(
510+
f"Model already dispatched via device_map: {self.model.hf_device_map}"
511+
)
512+
else:
513+
self.logger.info(f"Moving non-quantized model to device: {self.device}")
514+
self.model = self.model.to(self.device)
515+
# Ensure all parameters are on the same device
516+
for param in self.model.parameters():
517+
if param.device != torch.device(self.device):
518+
param.data = param.data.to(self.device)
513519

514520
# Verify final device placement
515521
devices = set()

0 commit comments

Comments (0)