
Commit d04eea1: Update entropy_decoding.py
1 parent: aef846c

File tree: 1 file changed (+17, -17)

optillm/entropy_decoding.py

Lines changed: 17 additions & 17 deletions
@@ -141,7 +141,7 @@ def entropy_decode(
     stop = torch.tensor([tokenizer.eos_token_id], device=device, dtype=torch.int32)
 
     for step in range(max_new_tokens):
-        logging.info(f"Generation step: {step + 1}")
+        logging.debug(f"Generation step: {step + 1}")
         with torch.no_grad():
             outputs = model(
                 input_ids if past_key_values is None else input_ids[:, -1:],
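The only change in this hunk is the log level: the per-step message now goes through logging.debug instead of logging.info, so a default INFO-level configuration no longer prints one line per generated token. A minimal sketch (not part of the commit) of how a caller could still opt back in to the per-step trace while debugging:

import logging

# Assumption: raising the root logger to DEBUG re-enables the per-step messages.
logging.basicConfig(level=logging.DEBUG)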
@@ -173,17 +173,17 @@ def entropy_decode(
                 next_token = torch.tensor([[2564]], dtype=torch.int32, device=device)
                 logging.debug("Inserting clarification token")
             else:
-                temp_adj = 1.3 + 0.2 * attention_metrics["attn_entropy"]
+                temp_adj = 1.3 + 0.2 * attention_metrics["attn_entropy"].item()
                 next_token = _sample(logits, temperature=min(1.5, temperature * temp_adj), top_p=top_p, top_k=top_k, min_p=min_p, generator=generator)
                 logging.debug(f"Using adjusted temperature sampling: {temp_adj:.3f}")
         elif entropy < 5.0 and varentropy > 5.0:
-            temp_adj = 1.2 + 0.3 * attention_metrics["interaction_strength"]
-            top_k_adj = max(5, int(top_k * (1 + 0.5 * (1 - attention_metrics["agreement"]))))
+            temp_adj = 1.2 + 0.3 * attention_metrics["interaction_strength"].item()
+            top_k_adj = max(5, int(top_k * (1 + 0.5 * (1 - attention_metrics["agreement"].item()))))
             next_token = _sample(logits, temperature=min(1.5, temperature * temp_adj), top_p=top_p, top_k=top_k_adj, min_p=min_p, generator=generator)
             logging.debug(f"Using exploration sampling: temp={temp_adj:.3f}, top_k={top_k_adj}")
         elif entropy > 5.0 and varentropy > 5.0:
-            temp_adj = 2.0 + 0.5 * attention_metrics["attn_varentropy"]
-            top_p_adj = max(0.5, top_p - 0.2 * attention_metrics["attn_entropy"])
+            temp_adj = 2.0 + 0.5 * attention_metrics["attn_varentropy"].item()
+            top_p_adj = max(0.5, top_p - 0.2 * attention_metrics["attn_entropy"].item())
             next_token = _sample(logits, temperature=max(2.0, temperature * temp_adj), top_p=top_p_adj, top_k=top_k, min_p=min_p, generator=generator)
             logging.debug(f"Using high uncertainty sampling: temp={temp_adj:.3f}, top_p={top_p_adj:.3f}")
         else:
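This hunk appends .item() wherever a value from attention_metrics feeds a temperature, top_k, or top_p adjustment. Since .item() only works on single-element tensors, these metrics are evidently 0-dimensional torch tensors; extracting the Python scalar keeps the adjusted sampling parameters as plain Python numbers rather than tensors. A minimal illustration with an assumed value (not from the repo):

import torch

attn_entropy = torch.tensor(2.5)              # stands in for attention_metrics["attn_entropy"]
print(type(1.3 + 0.2 * attn_entropy))         # <class 'torch.Tensor'>: the old behaviour
print(type(1.3 + 0.2 * attn_entropy.item()))  # <class 'float'>: after this commit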
@@ -218,17 +218,17 @@ def entropy_decode(
     return generated_text
 
 # Usage example
-from transformers import AutoModelForCausalLM, AutoTokenizer
+# from transformers import AutoModelForCausalLM, AutoTokenizer
 
-model_name = "Qwen/Qwen2.5-0.5B-Instruct"
-model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager")
-tokenizer = AutoTokenizer.from_pretrained(model_name)
+# model_name = "Qwen/Qwen2.5-0.5B-Instruct"
+# model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager")
+# tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-messages = [
-    {"role": "user", "content": "In a dance class of 20 students, 20% enrolled in contemporary dance, 25% of the remaining enrolled in jazz dance, and the rest enrolled in hip-hop dance. What percentage of the entire students enrolled in hip-hop dance?"}
-]
+# messages = [
+#     {"role": "user", "content": "In a dance class of 20 students, 20% enrolled in contemporary dance, 25% of the remaining enrolled in jazz dance, and the rest enrolled in hip-hop dance. What percentage of the entire students enrolled in hip-hop dance?"}
+# ]
 
-logging.info("Starting entropy decoding process")
-result = entropy_decode(model, tokenizer, messages)
-print(f"Entropy Decoding Result:\n{result}")
-logging.info("Entropy decoding process completed")
+# logging.info("Starting entropy decoding process")
+# result = entropy_decode(model, tokenizer, messages)
+# print(f"Entropy Decoding Result:\n{result}")
+# logging.info("Entropy decoding process completed")
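Here the whole module-level usage example is commented out, so importing entropy_decoding no longer downloads the Qwen model or runs a generation as a side effect. A sketch of an alternative (a suggestion, not part of this commit) that keeps the same example runnable behind a __main__ guard:

if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Same model and prompt as the original example; only executed when the
    # file is run directly, never on import.
    model_name = "Qwen/Qwen2.5-0.5B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager")
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    messages = [
        {"role": "user", "content": "In a dance class of 20 students, 20% enrolled in contemporary dance, 25% of the remaining enrolled in jazz dance, and the rest enrolled in hip-hop dance. What percentage of the entire students enrolled in hip-hop dance?"}
    ]

    result = entropy_decode(model, tokenizer, messages)
    print(f"Entropy Decoding Result:\n{result}")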
