@@ -63,6 +63,33 @@ static void llama_log_softmax(float * array, size_t size) {
6363} 
6464*/ 
6565
66+ static  void  llama_sampler_temp_impl (llama_token_data_array * cur_p, float  temp) {
67+     if  (temp <= 0 .0f ) {
68+         //  find the token with the highest logit and set the rest to -inf
69+         llama_token max_id = cur_p->data [0 ].id ;
70+         float  max_logit = cur_p->data [0 ].logit ;
71+ 
72+         for  (size_t  i = 1 ; i < cur_p->size ; ++i) {
73+             if  (cur_p->data [i].logit  > max_logit) {
74+                 max_id    = cur_p->data [i].id ;
75+                 max_logit = cur_p->data [i].logit ;
76+             }
77+         }
78+ 
79+         for  (size_t  i = 0 ; i < cur_p->size ; ++i) {
80+             if  (cur_p->data [i].id  != max_id) {
81+                 cur_p->data [i].logit  = -INFINITY;
82+             }
83+         }
84+ 
85+         return ;
86+     }
87+ 
88+     for  (size_t  i = 0 ; i < cur_p->size ; ++i) {
89+         cur_p->data [i].logit  /= temp;
90+     }
91+ }
92+ 
6693static  void  llama_sampler_softmax_impl (llama_token_data_array * cur_p) {
6794    GGML_ASSERT (cur_p->size  > 0 );
6895
@@ -916,30 +943,7 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*
916943static  void  llama_sampler_temp_apply (struct  llama_sampler  * smpl, llama_token_data_array * cur_p) {
917944    const  auto  * ctx = (llama_sampler_temp *) smpl->ctx ;
918945
919-     if  (ctx->temp  <= 0 .0f ) {
920-         //  find the token with the highest logit and set the rest to -inf
921-         llama_token max_id = cur_p->data [0 ].id ;
922-         float  max_logit = cur_p->data [0 ].logit ;
923- 
924-         for  (size_t  i = 1 ; i < cur_p->size ; ++i) {
925-             if  (cur_p->data [i].logit  > max_logit) {
926-                 max_id    = cur_p->data [i].id ;
927-                 max_logit = cur_p->data [i].logit ;
928-             }
929-         }
930- 
931-         for  (size_t  i = 0 ; i < cur_p->size ; ++i) {
932-             if  (cur_p->data [i].id  != max_id) {
933-                 cur_p->data [i].logit  = -INFINITY;
934-             }
935-         }
936- 
937-         return ;
938-     }
939- 
940-     for  (size_t  i = 0 ; i < cur_p->size ; ++i) {
941-         cur_p->data [i].logit  /= ctx->temp ;
942-     }
946+     llama_sampler_temp_impl (cur_p, ctx->temp );
943947}
944948
945949static  struct  llama_sampler  * llama_sampler_temp_clone (const  struct  llama_sampler  * smpl) {
@@ -1024,9 +1028,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
10241028    #endif 
10251029
10261030        //  Apply the dynamically calculated temperature scaling
1027-         for  (size_t  i = 0 ; i < cur_p->size ; ++i) {
1028-             cur_p->data [i].logit  /= dyn_temp;
1029-         }
1031+         llama_sampler_temp_impl (cur_p, dyn_temp);
10301032
10311033        //  Re-compute softmax probabilities after scaling logits with dynamic temperature
10321034        const  double  max_l_double = cur_p->data [0 ].logit ;
@@ -1050,9 +1052,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
10501052        }
10511053    #endif 
10521054    } else  {
1053-         for  (size_t  i = 0 ; i < cur_p->size ; ++i) {
1054-             cur_p->data [i].logit  /= ctx->temp ;
1055-         }
1055+         llama_sampler_temp_impl (cur_p, ctx->temp );
10561056    }
10571057}
10581058
0 commit comments