@@ -179,7 +179,7 @@ static llama_token llama_sampling_sample_impl(
179179 struct llama_context * ctx_main,
180180 struct llama_context * ctx_cfg,
181181 const int idx,
182- bool is_resampling) { // Add a parameter to indicate if we are resampling
182+ bool is_resampling) {
183183 const llama_sampling_params & params = ctx_sampling->params ;
184184
185185 const float temp = params.temp ;
@@ -188,8 +188,8 @@ static llama_token llama_sampling_sample_impl(
188188 const float mirostat_eta = params.mirostat_eta ;
189189
190190 std::vector<float > original_logits;
191- auto cur_p = llama_sampling_prepare (ctx_sampling, ctx_main, ctx_cfg, idx, ! is_resampling, &original_logits);
192- if (!is_resampling) {
191+ auto cur_p = llama_sampling_prepare (ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
192+ if (ctx_sampling-> grammar != NULL && !is_resampling) {
193193 GGML_ASSERT (!original_logits.empty ());
194194 }
195195 llama_token id = 0 ;
@@ -252,7 +252,7 @@ static llama_token llama_sampling_sample_impl(
252252 // Restore logits from the copy
253253 std::copy (original_logits.begin (), original_logits.end (), logits);
254254
255- return llama_sampling_sample_impl (ctx_sampling, ctx_main, ctx_cfg, idx, true ); // Pass true for is_resampling
255+ return llama_sampling_sample_impl (ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true );
256256 }
257257 }
258258
@@ -285,7 +285,8 @@ static llama_token_data_array llama_sampling_prepare_impl(
285285 // Get a pointer to the logits
286286 float * logits = llama_get_logits_ith (ctx_main, idx);
287287
288- if (apply_grammar && original_logits != NULL ) {
288+ if (ctx_sampling->grammar != NULL && !apply_grammar) {
289+ GGML_ASSERT (original_logits != NULL );
289290 // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
290291 *original_logits = {logits, logits + llama_n_vocab (llama_get_model (ctx_main))};
291292 }
@@ -342,7 +343,7 @@ llama_token llama_sampling_sample(
342343 struct llama_context * ctx_cfg,
343344 const int idx) {
344345 // Call the implementation function with is_resampling set to false by default
345- return llama_sampling_sample_impl (ctx_sampling, ctx_main, ctx_cfg, idx, false );
346+ return llama_sampling_sample_impl (ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false );
346347}
347348
348349llama_token_data_array llama_sampling_prepare (
0 commit comments