import torch
import torch.nn.functional as F
from transformers import PreTrainedModel, PreTrainedTokenizer
from typing import List, Tuple, Dict, Optional
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Device selection
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

logging.info(f"Using device: {device}")

LN_2 = 0.69314718056  # ln(2), used to convert entropies from nats to bits

def calculate_varentropy_logsoftmax(logits: torch.Tensor, axis: int = -1) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the entropy and varentropy (variance of the surprisal) of the token distribution, in bits."""
    log_probs = F.log_softmax(logits, dim=axis)
    probs = torch.exp(log_probs)
    entropy = -torch.sum(probs * log_probs, dim=axis) / LN_2  # convert to base-2
    varentropy = torch.sum(probs * (log_probs / LN_2 + entropy.unsqueeze(-1)) ** 2, dim=axis)
    return entropy, varentropy

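# For reference, with p = softmax(logits):
#   entropy    H = -sum_i p_i * log2(p_i)          (expected surprisal, in bits)
#   varentropy V =  sum_i p_i * (log2(p_i) + H)^2  (variance of the surprisal)
# e.g. a uniform distribution over 4 tokens gives H = 2 bits and V = 0,
# while a sharply peaked distribution gives H close to 0.
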
def calculate_attention_metrics(attention_scores: torch.Tensor) -> Dict[str, torch.Tensor]:
    """Summarize attention scores of shape (batch, heads, query_len, key_len) into a few summary statistics."""
    attention_probs = F.softmax(attention_scores, dim=-1)
    attn_entropy = -torch.sum(attention_probs * torch.log2(torch.clamp(attention_probs, 1e-10, 1.0)), dim=-1)
    attn_varentropy = torch.var(attn_entropy, dim=-1)

    # The variance over a single query position is NaN; treat it as zero
    attn_varentropy = torch.where(torch.isnan(attn_varentropy), torch.zeros_like(attn_varentropy), attn_varentropy)
    mean_attention = torch.mean(attention_probs, dim=1)
    agreement = torch.mean(torch.abs(attention_probs - mean_attention.unsqueeze(1)), dim=(1, 2))

    interaction_strength = torch.mean(torch.abs(attention_scores), dim=(1, 2, 3))

    return {
        "attn_entropy": torch.mean(attn_entropy),
        "attn_varentropy": torch.mean(attn_varentropy),
        "agreement": torch.mean(agreement),
        "interaction_strength": interaction_strength
    }

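# Rough reading of the attention metrics above:
#   attn_entropy         - how spread out attention is over the keys (in bits)
#   attn_varentropy      - how much that spread varies across query positions
#   agreement            - mean absolute deviation of each head from the head-averaged
#                          attention map (larger values mean the heads disagree more)
#   interaction_strength - mean absolute pre-softmax attention score
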
def _sample(logits: torch.Tensor, temperature=0.666, top_p=0.90, top_k=27, min_p: float = 0.0,
            generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Sample the next token using temperature, min-p, top-k and top-p filtering."""
    bsz = logits.shape[0]
    logit = logits[:, -1]
    probs = F.softmax(logit / temperature, dim=-1)

    # min-p filtering: drop tokens whose probability is below min_p times the largest probability
    if min_p > 0.0:
        p_max = torch.max(probs, dim=-1, keepdim=True).values
        indices_to_remove = probs < (min_p * p_max)
        logit = torch.where(indices_to_remove, torch.full_like(logit, float('-inf')), logit)
        probs = F.softmax(logit / temperature, dim=-1)  # renormalize over the surviving tokens

    # top-k candidates, then a top-p style mask computed over the candidates in ascending order
    top_k_probs, top_k_indices = torch.topk(probs, k=min(top_k, probs.shape[-1]))
    probs_sort = torch.flip(top_k_probs, dims=[-1])
    probs_idx = torch.flip(top_k_indices, dims=[-1])
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = torch.where(probs_sum - probs_sort > top_p, torch.tensor(1.0, device=device), torch.tensor(0.0, device=device))
    probs_sort = probs_sort * (1 - mask)
    probs_sort = probs_sort / torch.sum(probs_sort, dim=-1, keepdim=True)
    next_token = torch.multinomial(probs_sort, 1, generator=generator)
    next_token_g = torch.gather(probs_idx, -1, next_token.reshape(bsz, 1).to(torch.int64))
    return next_token_g.to(torch.int32)

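# Minimal sketch of calling the sampler directly (shapes assumed: batch=1, seq=1, vocab=8):
#   dummy_logits = torch.randn(1, 1, 8, device=device)
#   tok = _sample(dummy_logits, temperature=0.7, top_p=0.9, top_k=5,
#                 generator=torch.Generator(device=device).manual_seed(0))
#   # tok has shape (1, 1) and dtype torch.int32
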
def adaptive_sample(logits: torch.Tensor, metrics: Dict[str, torch.Tensor],
                    gen_tokens: torch.Tensor, n_samples: int,
                    base_temp: float = 0.666, base_top_p: float = 0.90, base_top_k: int = 40, base_min_p: float = 0.03,
                    generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Draw several candidate tokens with uncertainty-adjusted parameters and return the best-scoring one."""
    logits_uncertainty = metrics["logits_entropy"] + metrics["logits_varentropy"]
    attn_uncertainty = metrics["attn_entropy"] + metrics["attn_varentropy"]

    # Scale the sampling parameters up or down based on the uncertainty metrics
    temperature = base_temp * (1 + 0.3 * logits_uncertainty + 0.2 * attn_uncertainty - 0.2 * metrics["agreement"])
    top_p = torch.clamp(base_top_p * (1 + 0.1 * metrics["attn_varentropy"]), 0.1, 1.0)
    top_k = int(torch.clamp(
        torch.round(torch.tensor(base_top_k) * (1 + 0.3 * metrics["interaction_strength"].item() - 0.2 * metrics["agreement"].item())),
        min=1,
        max=100
    ).item())
    min_p = torch.clamp(base_min_p * (1 - 0.5 * logits_uncertainty), 0.01, 0.5)

    logging.debug(f"Adaptive sampling params: temp={temperature.mean().item():.3f}, top_p={top_p.item():.3f}, "
                  f"top_k={top_k}, min_p={min_p.mean().item():.3f}")

    samples = []
    for _ in range(n_samples):
        sample = _sample(logits, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, generator=generator)
        samples.append(sample)

    def score_sample(sample):
        # Log-probability of the sampled token under the model's distribution
        sample_flat = sample.flatten().to(torch.long)
        one_hot = F.one_hot(sample_flat, logits.shape[-1])
        log_probs = F.log_softmax(logits, dim=-1).view(-1, logits.shape[-1])
        log_prob = torch.sum(log_probs * one_hot)

        # Fixed-weight confidence bonus built from the uncertainty metrics
        confidence_score = (
            (1 - metrics["logits_entropy"]) * 0.1 +
            (1 - metrics["attn_entropy"]) * 0.2 +
            (1 - metrics["logits_varentropy"]) * 0.3 +
            (1 - metrics["attn_varentropy"]) * 0.4 +
            metrics["agreement"] * 0.5 +
            metrics["interaction_strength"] * 0.6
        )
        return log_prob + confidence_score

    sample_scores = torch.stack([score_sample(sample) for sample in samples])
    best_sample_idx = torch.argmax(sample_scores)
    return samples[best_sample_idx]

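# Note: the candidate score above adds the model log-probability to a fixed linear combination of the
# metrics (weights 0.1-0.6); the weighting is a heuristic rather than a calibrated confidence score.
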
def entropy_decode(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    messages: List[Dict[str, str]],
    max_new_tokens: int = 512,
    temperature: float = 0.666,
    top_p: float = 0.90,
    top_k: int = 27,
    min_p: float = 0.03,
    generator: Optional[torch.Generator] = None
) -> str:
    """Generate a reply to `messages`, picking a sampling strategy per step from entropy and attention metrics."""
    if generator is None:
        generator = torch.Generator(device=device).manual_seed(1337)

    model.to(device)
    logging.info("Starting entropy decoding")

    # Build the prompt, preferring the tokenizer's chat template when one is available
    if hasattr(tokenizer, 'chat_template') and tokenizer.chat_template:
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        input_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
        input_text += "\nassistant:"

    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
    attention_mask = torch.ones_like(input_ids).to(device)

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    generated_tokens = []
    gen_tokens = input_ids
    past_key_values = None
    stop = torch.tensor([tokenizer.eos_token_id], device=device, dtype=torch.int32)

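    # Incremental decoding: after the first forward pass only the newest token is fed to the model,
    # and the cached key/value states (past_key_values) are reused for the rest of the context.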
    for step in range(max_new_tokens):
        logging.info(f"Generation step: {step + 1}")
        with torch.no_grad():
            outputs = model(
                input_ids if past_key_values is None else input_ids[:, -1:],
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
                output_attentions=True,
            )

        logits = outputs.logits[:, -1:, :]
        attention_scores = outputs.attentions[-1]  # attention weights of the last layer
        past_key_values = outputs.past_key_values

        entropy, varentropy = calculate_varentropy_logsoftmax(logits)
        attention_metrics = calculate_attention_metrics(attention_scores)
        metrics = {
            "logits_entropy": entropy,
            "logits_varentropy": varentropy,
            **attention_metrics
        }

        logging.debug(f"Metrics: entropy={entropy.item():.3f}, varentropy={varentropy.item():.3f}")

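        # Choose a decoding strategy from the (entropy, varentropy) regime of the next-token distribution:
        #   low entropy,  low varentropy  -> the model is confident: take the argmax
        #   high entropy, low varentropy  -> uniformly uncertain: insert a clarification token, or
        #                                    resample with a raised temperature if one was just inserted
        #   low entropy,  high varentropy -> a few competing continuations: explore with a larger top_k
        #   high entropy, high varentropy -> very uncertain: sample with a high temperature and tighter top_p
        #   anything else                 -> fall back to adaptive, metric-weighted sampling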
        if entropy < 0.1 and varentropy < 0.1:
            next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True).to(torch.int32)
            logging.debug("Using greedy sampling")
        elif entropy > 3.0 and varentropy < 0.1:
            # Token id 2564 is hardcoded as the "clarification" token; it is assumed to map to a suitable
            # clarifying word for the tokenizer in use and should be adjusted for other models.
            if not torch.isin(gen_tokens[:, -1], torch.tensor([2564], device=device)).any():
                next_token = torch.tensor([[2564]], dtype=torch.int32, device=device)
                logging.debug("Inserting clarification token")
            else:
                temp_adj = 1.3 + 0.2 * attention_metrics["attn_entropy"]
                next_token = _sample(logits, temperature=min(1.5, temperature * temp_adj), top_p=top_p, top_k=top_k, min_p=min_p, generator=generator)
                logging.debug(f"Using adjusted temperature sampling: {temp_adj:.3f}")
        elif entropy < 5.0 and varentropy > 5.0:
            temp_adj = 1.2 + 0.3 * attention_metrics["interaction_strength"].item()
            top_k_adj = max(5, int(top_k * (1 + 0.5 * (1 - attention_metrics["agreement"]))))
            next_token = _sample(logits, temperature=min(1.5, temperature * temp_adj), top_p=top_p, top_k=top_k_adj, min_p=min_p, generator=generator)
            logging.debug(f"Using exploration sampling: temp={temp_adj:.3f}, top_k={top_k_adj}")
        elif entropy > 5.0 and varentropy > 5.0:
            temp_adj = 2.0 + 0.5 * attention_metrics["attn_varentropy"]
            top_p_adj = max(0.5, top_p - 0.2 * attention_metrics["attn_entropy"])
            next_token = _sample(logits, temperature=max(2.0, temperature * temp_adj), top_p=top_p_adj, top_k=top_k, min_p=min_p, generator=generator)
            logging.debug(f"Using high uncertainty sampling: temp={temp_adj:.3f}, top_p={top_p_adj:.3f}")
        else:
            next_token = adaptive_sample(
                logits,
                metrics,
                gen_tokens,
                n_samples=5,
                base_temp=temperature,
                base_top_p=top_p,
                base_top_k=top_k,
                base_min_p=min_p,
                generator=generator
            )
            logging.debug("Using adaptive sampling")

        generated_tokens.append(next_token.item())
        gen_tokens = torch.cat((gen_tokens, next_token), dim=1)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
        attention_mask = torch.cat([attention_mask, torch.ones((1, 1), device=device, dtype=torch.long)], dim=-1)

        logging.debug(f"Generated token: {tokenizer.decode([next_token.item()])}")

        if torch.isin(next_token, stop).any():
            logging.info("Reached stop token. Ending generation.")
            break

    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    logging.info("Finished entropy decoding")
    logging.info(f"Generated text: {generated_text}")

    return generated_text

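# Note: the decoding loop requests output_attentions=True, so the model should be loaded with an
# attention implementation that returns attention weights (hence attn_implementation="eager" below).
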
# Usage example
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager")
tokenizer = AutoTokenizer.from_pretrained(model_name)

messages = [
    {"role": "user", "content": "In a dance class of 20 students, 20% enrolled in contemporary dance, 25% of the remaining enrolled in jazz dance, and the rest enrolled in hip-hop dance. What percentage of the entire students enrolled in hip-hop dance?"}
]

logging.info("Starting entropy decoding process")
result = entropy_decode(model, tokenizer, messages)
print(f"Entropy Decoding Result:\n{result}")
logging.info("Entropy decoding process completed")