
Commit 38830a7

decoding updates
1 parent 5d471d1 commit 38830a7

File tree

optillm/cot_decoding.py
optillm/entropy_decoding.py

2 files changed: +251, -1 lines changed

optillm/cot_decoding.py

Lines changed: 18 additions & 1 deletion
@@ -4,7 +4,9 @@
 import numpy as np
 
 def get_device():
-    if torch.cuda.is_available():
+    if torch.backends.mps.is_available():
+        return torch.device("mps")
+    elif torch.cuda.is_available():
         return torch.device("cuda")
     else:
         return torch.device("cpu")
@@ -143,3 +145,18 @@ def cot_decode(
     else:
         return max(paths, key=lambda x: x[1])
 
+# Usage example
+# from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# model_name = "Qwen/Qwen2.5-0.5B-Instruct"
+# model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager")
+# tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+# messages = [
+#     {"role": "user", "content": "In a dance class of 20 students, 20% enrolled in contemporary dance, 25% of the remaining enrolled in jazz dance, and the rest enrolled in hip-hop dance. What percentage of the entire students enrolled in hip-hop dance?"}
+# ]
+
+# # Generate the response using CoT decoding
+# print(f"Using device: {get_device()}")
+# result, confidence = cot_decode(model, tokenizer, messages, aggregate_paths=True, max_new_tokens=512)
+# print(f"CoT Decoding:\n {result}")

optillm/entropy_decoding.py

Lines changed: 233 additions & 0 deletions
@@ -0,0 +1,233 @@
+import torch
+import torch.nn.functional as F
+from transformers import PreTrainedModel, PreTrainedTokenizer
+from typing import List, Tuple, Dict, Optional
+import logging
+
+# Set up logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+# Device selection
+if torch.backends.mps.is_available():
+    device = torch.device("mps")
+elif torch.cuda.is_available():
+    device = torch.device("cuda")
+else:
+    device = torch.device("cpu")
+
+logging.info(f"Using device: {device}")
+
+LN_2 = 0.69314718056  # ln(2)
+
+def calculate_varentropy_logsoftmax(logits: torch.Tensor, axis: int = -1) -> Tuple[torch.Tensor, torch.Tensor]:
+    log_probs = F.log_softmax(logits, dim=axis)
+    probs = torch.exp(log_probs)
+    entropy = -torch.sum(probs * log_probs, dim=axis) / LN_2  # Convert to base-2
+    varentropy = torch.sum(probs * (log_probs / LN_2 + entropy.unsqueeze(-1))**2, dim=axis)
+    return entropy, varentropy
+
+def calculate_attention_metrics(attention_scores: torch.Tensor) -> Dict[str, torch.Tensor]:
+    attention_probs = F.softmax(attention_scores, dim=-1)
+    attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clamp(attention_probs, 1e-10, 1.0)), dim=-1)
+    attn_varentropy = torch.var(attn_entropy, dim=-1)
+
+    attn_varentropy = torch.where(torch.isnan(attn_varentropy), torch.zeros_like(attn_varentropy), attn_varentropy)
+    mean_attention = torch.mean(attention_probs, dim=1)
+    agreement = torch.mean(torch.abs(attention_probs - mean_attention.unsqueeze(1)), dim=(1, 2))
+
+    interaction_strength = torch.mean(torch.abs(attention_scores), dim=(1, 2, 3))
+
+    return {
+        "attn_entropy": torch.mean(attn_entropy),
+        "attn_varentropy": torch.mean(attn_varentropy),
+        "agreement": torch.mean(agreement),
+        "interaction_strength": interaction_strength
+    }
+
+def _sample(logits: torch.Tensor, temperature=0.666, top_p=0.90, top_k=27, min_p: float = 0.0, generator: torch.Generator = None) -> torch.Tensor:
+    bsz = logits.shape[0]
+    logit = logits[:, -1]
+    probs = F.softmax(logit / temperature, dim=-1)
+
+    if min_p > 0.0:
+        p_max = torch.max(probs, dim=-1, keepdim=True).values
+        indices_to_remove = probs < (min_p * p_max)
+        logit = torch.where(indices_to_remove, torch.full_like(logit, float('-inf')), logit)
+
+    top_k_probs, top_k_indices = torch.topk(probs, k=min(top_k, probs.shape[-1]))
+    probs_sort = torch.flip(top_k_probs, dims=[-1])
+    probs_idx = torch.flip(top_k_indices, dims=[-1])
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = torch.where(probs_sum - probs_sort > top_p, torch.tensor(1.0, device=device), torch.tensor(0.0, device=device))
+    probs_sort = probs_sort * (1 - mask)
+    probs_sort = probs_sort / torch.sum(probs_sort, dim=-1, keepdim=True)
+    next_token = torch.multinomial(probs_sort, 1, generator=generator)
+    next_token_g = torch.gather(probs_idx, -1, next_token.reshape(bsz, 1).to(torch.int64))
+    return next_token_g.to(torch.int32)
+
+def adaptive_sample(logits: torch.Tensor, metrics: Dict[str, torch.Tensor],
+                    gen_tokens: torch.Tensor, n_samples: int,
+                    base_temp: float = 0.666, base_top_p: float = 0.90, base_top_k: int = 40, base_min_p: float = 0.03,
+                    generator: torch.Generator = None) -> torch.Tensor:
+    logits_uncertainty = metrics["logits_entropy"] + metrics["logits_varentropy"]
+    attn_uncertainty = metrics["attn_entropy"] + metrics["attn_varentropy"]
+
+    temperature = base_temp * (1 + 0.3 * logits_uncertainty + 0.2 * attn_uncertainty - 0.2 * metrics["agreement"])
+    top_p = torch.clamp(base_top_p * (1 + 0.1 * metrics["attn_varentropy"]), 0.1, 1.0)
+    top_k = int(torch.clamp(
+        torch.round(torch.tensor(base_top_k) * (1 + 0.3 * metrics["interaction_strength"].item() - 0.2 * metrics["agreement"].item())),
+        min=1,
+        max=100
+    ).item())
+    min_p = torch.clamp(base_min_p * (1 - 0.5 * logits_uncertainty), 0.01, 0.5)
+
+    logging.debug(f"Adaptive sampling params: temp={temperature:.3f}, top_p={top_p:.3f}, top_k={top_k}, min_p={min_p:.3f}")
+
+    samples = []
+    for _ in range(n_samples):
+        sample = _sample(logits, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, generator=generator)
+        samples.append(sample)
+
+    def score_sample(sample):
+        sample_flat = sample.flatten().to(torch.long)
+        one_hot = F.one_hot(sample_flat, logits.shape[-1])
+        log_probs = F.log_softmax(logits, dim=-1).view(-1, logits.shape[-1])
+        log_prob = torch.sum(log_probs * one_hot)
+
+        confidence_score = (
+            (1 - metrics["logits_entropy"]) * 0.1 +
+            (1 - metrics["attn_entropy"]) * 0.2 +
+            (1 - metrics["logits_varentropy"]) * 0.3 +
+            (1 - metrics["attn_varentropy"]) * 0.4 +
+            metrics["agreement"] * 0.5 +
+            metrics["interaction_strength"] * 0.6
+        )
+        return log_prob + confidence_score
+
+    sample_scores = torch.stack([score_sample(sample) for sample in samples])
+    best_sample_idx = torch.argmax(sample_scores)
+    return samples[best_sample_idx]
+
+def entropy_decode(
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizer,
+    messages: List[Dict[str, str]],
+    max_new_tokens: int = 512,
+    temperature: float = 0.666,
+    top_p: float = 0.90,
+    top_k: int = 27,
+    min_p: float = 0.03,
+    generator: torch.Generator = torch.Generator(device=device).manual_seed(1337)
+) -> str:
+    model.to(device)
+    logging.info("Starting entropy decoding")
+
+    if hasattr(tokenizer, 'chat_template') and tokenizer.chat_template:
+        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    else:
+        input_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
+        input_text += "\nassistant:"
+
+    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
+    attention_mask = torch.ones_like(input_ids).to(device)
+
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+
+    generated_tokens = []
+    gen_tokens = input_ids
+    past_key_values = None
+    stop = torch.tensor([tokenizer.eos_token_id], device=device, dtype=torch.int32)
+
+    for step in range(max_new_tokens):
+        logging.info(f"Generation step: {step + 1}")
+        with torch.no_grad():
+            outputs = model(
+                input_ids if past_key_values is None else input_ids[:, -1:],
+                attention_mask=attention_mask,
+                past_key_values=past_key_values,
+                use_cache=True,
+                output_attentions=True,
+            )
+
+        logits = outputs.logits[:, -1:, :]
+        attention_scores = outputs.attentions[-1]
+        past_key_values = outputs.past_key_values
+
+        entropy, varentropy = calculate_varentropy_logsoftmax(logits)
+        attention_metrics = calculate_attention_metrics(attention_scores)
+        metrics = {
+            "logits_entropy": entropy,
+            "logits_varentropy": varentropy,
+            **attention_metrics
+        }
+
+        logging.debug(f"Metrics: entropy={entropy.item():.3f}, varentropy={varentropy.item():.3f}")
+
+        if entropy < 0.1 and varentropy < 0.1:
+            next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True).to(torch.int32)
+            logging.debug("Using greedy sampling")
+        elif entropy > 3.0 and varentropy < 0.1:
+            if not torch.isin(gen_tokens[:, -1], torch.tensor([2564], device=device)).any():
+                next_token = torch.tensor([[2564]], dtype=torch.int32, device=device)
+                logging.debug("Inserting clarification token")
+            else:
+                temp_adj = 1.3 + 0.2 * attention_metrics["attn_entropy"]
+                next_token = _sample(logits, temperature=min(1.5, temperature * temp_adj), top_p=top_p, top_k=top_k, min_p=min_p, generator=generator)
+                logging.debug(f"Using adjusted temperature sampling: {temp_adj:.3f}")
+        elif entropy < 5.0 and varentropy > 5.0:
+            temp_adj = 1.2 + 0.3 * attention_metrics["interaction_strength"]
+            top_k_adj = max(5, int(top_k * (1 + 0.5 * (1 - attention_metrics["agreement"]))))
+            next_token = _sample(logits, temperature=min(1.5, temperature * temp_adj), top_p=top_p, top_k=top_k_adj, min_p=min_p, generator=generator)
+            logging.debug(f"Using exploration sampling: temp={temp_adj:.3f}, top_k={top_k_adj}")
+        elif entropy > 5.0 and varentropy > 5.0:
+            temp_adj = 2.0 + 0.5 * attention_metrics["attn_varentropy"]
+            top_p_adj = max(0.5, top_p - 0.2 * attention_metrics["attn_entropy"])
+            next_token = _sample(logits, temperature=max(2.0, temperature * temp_adj), top_p=top_p_adj, top_k=top_k, min_p=min_p, generator=generator)
+            logging.debug(f"Using high uncertainty sampling: temp={temp_adj:.3f}, top_p={top_p_adj:.3f}")
+        else:
+            next_token = adaptive_sample(
+                logits,
+                metrics,
+                gen_tokens,
+                n_samples=5,
+                base_temp=temperature,
+                base_top_p=top_p,
+                base_top_k=top_k,
+                base_min_p=min_p,
+                generator=generator
+            )
+            logging.debug("Using adaptive sampling")
+
+        generated_tokens.append(next_token.item())
+        gen_tokens = torch.cat((gen_tokens, next_token), dim=1)
+        input_ids = torch.cat([input_ids, next_token], dim=-1)
+        attention_mask = torch.cat([attention_mask, torch.ones((1, 1), device=device, dtype=torch.long)], dim=-1)
+
+        logging.debug(f"Generated token: {tokenizer.decode([next_token.item()])}")
+
+        if torch.isin(next_token, stop).any():
+            logging.info("Reached stop token. Ending generation.")
+            break
+
+    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+    logging.info("Finished entropy decoding")
+    logging.info(f"Generated text: {generated_text}")
+
+    return generated_text
+
+# Usage example
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_name = "Qwen/Qwen2.5-0.5B-Instruct"
+model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager")
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+messages = [
+    {"role": "user", "content": "In a dance class of 20 students, 20% enrolled in contemporary dance, 25% of the remaining enrolled in jazz dance, and the rest enrolled in hip-hop dance. What percentage of the entire students enrolled in hip-hop dance?"}
+]
+
+logging.info("Starting entropy decoding process")
+result = entropy_decode(model, tokenizer, messages)
+print(f"Entropy Decoding Result:\n{result}")
+logging.info("Entropy decoding process completed")

0 commit comments
