Skip to content

Conversation

@MaggotHATE
Copy link
Contributor

@MaggotHATE MaggotHATE commented Oct 25, 2024

K-Shift is a sampling strategy mentioned in a Chain-of-Thought Reasoning without Prompting paper and is meant to guide models away from the most obvious startup in inference by cutting out a defined amount of tokens once at the start of the dialog. Since all the rest tokens are not being affected by the sampler, the output is still coherent. K-Shift is intended to be used with greedy sampling, and it is claimed to help with steering models more towards reasoning instead of short answers.

Since a recent commit changed how greedy sampling is achieved, this sampler fits in the main sampling queue and can be combined with top_k = 1 setting. In my experience it helped with getting different reasoning, less cliched starts in creative writing and can even change bias of the model - reducing or inducing refusals.

Examples with Mistral-Nemo-Instruct-2407.q5_k_l:
k = 0

[INST]Kylar went to the store to buy glasses for his new apartment. One glass costs $5, but every second glass costs only 60% of the price. Kylar wants to buy 16 glasses. How much does he need to pay for them?[/INST]Let's calculate the total cost step by step.

First, we need to determine how many glasses will be at the full price and how many will be at the discounted price.

Since every second glass is at a discount, we can divide the total number of glasses by 2 to find out how many will be discounted.

16 glasses / 2 = 8 glasses at full price
8 glasses / 2 = 4 glasses at discounted price

Now, let's calculate the cost for each group:

Full price glasses: 8 glasses * $5/glass = $40
Discounted glasses: 4 glasses * (60% of $5) = 4 glasses * $3/glass = $12

Finally, we add the cost of the full price glasses and the discounted glasses to get the total cost:

Total cost = Cost of full price glasses + Cost of discounted glasses
Total cost = $40 + $12
Total cost = $52

Kylar needs to pay $52 for the 16 glasses.[INST]

k = 5

[INST]Kylar went to the store to buy glasses for his new apartment. One glass costs $5, but every second glass costs only 60% of the price. Kylar wants to buy 16 glasses. How much does he need to pay for them?[/INST]Every second glass costs 60% of the price of $5, which is 0.60 * $5 = $3.

So, for every two glasses, Kylar pays $5 + $3 = $8.

Since Kylar wants to buy 16 glasses, we can divide this number by 2 to find out how many pairs of glasses he will buy:

16 glasses / 2 = 8 pairs of glasses.

Now, we multiply the cost of one pair of glasses by the number of pairs:

8 pairs * $8/pair = $64.

Therefore, Kylar needs to pay $64 for the 16 glasses.[INST]

k = 14

[INST]Kylar went to the store to buy glasses for his new apartment. One glass costs $5, but every second glass costs only 60% of the price. Kylar wants to buy 16 glasses. How much does he need to pay for them?[/INST]Here's how we can calculate the total cost for Kylar:

1. Kylar buys 16 glasses in total.
2. Every second glass costs 60% of the price of the first glass. So, for every two glasses, the cost is $5 (for the first glass) + 60% of $5 (for the second glass).
3. Let's calculate the cost for two glasses:
   - First glass: $5
   - Second glass: 60% of $5 = 0.60 * $5 = $3
   - Total for two glasses: $5 + $3 = $8
4. Now, we need to find out how many pairs of two glasses Kylar buys and how many single glasses he buys:
   - Since 16 is an even number, Kylar can buy 8 pairs of two glasses.
   - For the 8 pairs, the cost is 8 * $8 = $64.
5. Kylar doesn't need to buy any single glasses because 16 is an even number and he can buy 8 pairs of two glasses.
6. Therefore, the total cost for Kylar is $64.

Kylar needs to pay $64 for the 16 glasses.[INST]

This sampler is still in testing, but it feels like a good improvement to sampling overall - however, every model might need its own value for k. With K-Shift and XTC, greedy sampling might become useful even for creative sampling.

@MaggotHATE MaggotHATE marked this pull request as ready for review October 25, 2024 19:37
@github-actions github-actions bot added the testing Everything test related label Oct 26, 2024
@MaggotHATE
Copy link
Contributor Author

@p-e-w May I ask you for a review and maybe testing even, please? While this sampler is very simple by itself, it has quite a strong effect in practice and can be useful as an additional control measure in creative sampling. I've tested K-Shift with and without XTC, and it looks like they can work together quite nicely - just need to keep in mind how far the first cutout may go.

@p-e-w
Copy link

p-e-w commented Oct 28, 2024

I am currently sick and will be off the computer for a few days, but I intend to do a full review of this interesting PR soon.

@MaggotHATE
Copy link
Contributor Author

I am currently sick and will be off the computer for a few days

Get well! In meantime I will be testing K-Shift further to gather more data on different models (tested Nemo/Mistral Small/Gemma 2 - all behave differently so far).

@qnixsynapse
Copy link
Collaborator

@MaggotHATE Hi. Just a doubt. This sampler works like this right? Select top n tokens at the beginning then do greedy decode for each one of them and select the beam with the highest probability? Will this increase the decode time or streaming needs to be disabled? Or alternative beams can be decoded in parallel? I did try to read the code but I am not too much familiar with llamacpp api so, I end up asking. Please pardon my ignorance.

For example, in the paper, this is with k=5
image

@MaggotHATE
Copy link
Contributor Author

@qnixsynapse

Or alternative beams can be decoded in parallel?

There are no alternative beams in K-Shift - that would be CoT-decoding, the main subject of this paper. K-Shift is simply choosing the exact path at the start of inference ("Decoding step 0") by cutting out k top tokens. No further evaluation of "confidence" is performed for now as it is outside of this simple solution. It also means that K-Shift is easier to combine with existing samplers.

I plan on implementing CoT-decoding in a different sampler, but I imagine it would be quite a bulky solution within llama.cpp.

@MaggotHATE
Copy link
Contributor Author

MaggotHATE commented Oct 31, 2024

I just realized that adding ctx->k_set = false into llama_sampler_k_shift_reset changed how this sampler was supposed to work, as common_sampler_reset is called after each message in main. However, I'm not sure if the same issue would happen for server.

@slaren Is common_sampler_reset supposed to be called each time before inference, like here? If yes, I will remove custom llama_sampler_k_shift_reset, but it would be nice to be able to reset it some other way for messages regeneration (it's still not a feature in llama.cpp, but just in case).

@slaren
Copy link
Member

slaren commented Oct 31, 2024

I don't know, I am not sure when that was added, but I think it makes sense. What's the downside of resetting the sampler state after each message? I would think that you wouldn't want to apply the repetition penalties etc of the previous message to the next message. cc @ggerganov

@MaggotHATE
Copy link
Contributor Author

What's the downside of resetting the sampler state after each message?

The only downside I see is tracking switches/states within samplers themselves in cases when a sampler should be applied once (like K-Shift, for example). On reset, either the sampler will be applied again, or, without custom reset function, we won't be able to revert the switch without deleting sampler object.

@slaren
Copy link
Member

slaren commented Nov 1, 2024

Removing the reset function is definitely not correct here, since it breaks the assumption that samplers can be re-used with different sequences by resetting them. I have looked at the paper that you linked and I think that what when they say "first decoding step" it could be very well interpreted as the first token of each response in a conversation.

However, what the paper is talking about cannot be implemented with a sampler alone. The paper is talking about generating k different sequences for the response, each starting with a different token, and then aggregating the results. That would be interesting to implement in an example as a proof of concept, but as it is, I don't think that this sampler would be useful by itself without the rest of the algorithm. A bonus would be implementing this using multiple sequences to generate all the responses at the same time in parallel.

@MaggotHATE
Copy link
Contributor Author

MaggotHATE commented Nov 1, 2024

I think that what when they say "first decoding step" it could be very well interpreted as the first token of each response in a conversation.

Alright, I will revert it back then. In recent tests it was still coherent even with reset. Although, it would be nice to have a way to trigger it once per session. Is it even possible in the current samplers chain implementation?

I don't think that this sampler would be useful by itself without the rest of the algorithm.

I've tested it in practice, and it actually works quite well by itself. In a way, it works similarly to XTC, but under more strict conditions. That alone makes K-Shift more compatible with greedy sampling.

As for the main method in the paper, it is interesting, but it will likely become another example app with no prospects of being in main or server. K-Shift is more practical, even if (or maybe because) it's simple and straightforward.

Copy link

@p-e-w p-e-w left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The technique described in the paper is indeed very promising, but I have to agree with @slaren: This sampler is not very useful without the algorithmic framework outlined by the paper.

The problem is that unlike XTC, K-Shift sampling fails to make any guarantees about token probabilities. The 5th-highest token might have a probability of 7.3% (in which case it could be an interesting path to explore) or 0.000000032% (in which case it is almost certainly garbage), and the sampler cannot distinguish between the two. Thus this is a trial-and-error sampler for which you have to do trial-and-error every time you use it.

To make K-Shift sampling useful, it needs either

  • an implementation of the confidence-maximizing beam search strategy described in the paper, or
  • additional parameters that allow it to take probability magnitudes into account.

As it stands, all this sampler does is choose the first word of the output for you, which you can do yourself by just adding it to the input, in about as much time as it takes to fiddle with the parameter.


if (ctx->k_set == true
|| ctx->k <= 0
|| ctx->k >= (int) cur_p->size) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There appears to be a bug here: If at the first token position, ctx->k >= (int) cur_p->size (e.g. because a preceding truncation sampler has already removed too many tokens) then we return and k_set remains false. This means that K-Shift will only take effect on the second (or later) token of the output, violating its contract.


// shift to a token #[k]
cur_p->data += k;
cur_p->size -= k;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not match the paper, in which (AFAICT) exactly the k-th token is selected, rather than sampling from all tokens except the top k-1 ones.

Copy link
Contributor Author

@MaggotHATE MaggotHATE Nov 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The choice of token will be handled by the final step in sampling queue, so greedy sampling would be needed to match the effect described in paper. Considering that specifically greedy sampler was removed recently, I don't think introducing another "final step" sampler would be ok.

@MaggotHATE
Copy link
Contributor Author

As it stands, all this sampler does is choose the first word of the output for you,

...which is something no other sampler can do for now.

which you can do yourself by just adding it to the input, in about as much time as it takes to fiddle with the parameter.

The problem is that whatever you add into the input might not match the natural flow of things for the model: you would either have to look at candidates presented by the model at the first step, or to just force something you need while ignoring the candidates.

Having worked with this for a while now, I see benefits from K-Shift as a simple guidance that is more interesting and effective than simply adjusting suffixes (or adjusting completion results) every time you need a specific start of the output. It's good to have a guaranteed result, but doesn't always work as we want - otherwise the paper wouldn't exist, and CoT instructions would've have a guaranteed effect.

I think I'll look at probabilities control though, it might be interesting.

* the parameter limits how far K-Shift cuts by checking probability of the last token and iterating backwards if it's not probable enough
@MaggotHATE
Copy link
Contributor Author

@p-e-w I've added shift_p_min parameter to set the minimum required probability. Interestingly, the value at which it's still coherent might not be as high as we imagine. For example:

Tested on MS-Meadowlark-22B-Q4_K_L, K-Shift -> top_k = 1 (for greedy sampling).

Without limits, k = 50 (extreme case) :

<s>[INST] What's a paladin?[/INST] Hmm, I'm not sure I understand your question. Could you please provide more context or clarify what you mean by "paladin"?[INST] 

Candidates logged for statistics (first 8, probability|logit):

(32768):  #0[0.892989|18.359356] #1[0.068965|15.798388] #2[0.007043|13.516762] #3[0.005626|13.292268] #4[0.005137|13.201305] #5[0.002399|12.439665] #6[0.002283|12.390135] #7[0.002116|12.314234]
 Will cut at 0.000044

The answer is hilarious, yet still coherent in it's own way.

The same, but with a limit of 0.0001:

<s>[INST] What's a paladin?[/INST] ["PAL-uh-din"]  A paladin is a type of character often found in fantasy settings, particularly in role-playing games and literature. Here are some key characteristics of a paladin:

1. **Holy Warrior**: Paladins are typically warriors who serve a deity or a cause of good. They are often associated with religious orders or knightly organizations.

2. **Code of Conduct**: Paladins usually adhere to a strict code of conduct that emphasizes honor, chivalry, and righteousness. This code often includes rules about how they should behave in battle, how they should treat others, and how they should uphold justice.

3. **Magical Abilities**: In addition to their combat skills, paladins often have access to divine magic. This can include healing spells, protective spells, and spells that deal damage to enemies.

4. **Special Abilities**: Paladins may have unique abilities that set them apart from other warriors. These can include the ability to detect evil, the ability to turn undead creatures, or the ability to smite enemies with divine power.

5. **Symbol of Hope**: Paladins are often seen as symbols of hope and justice. They are expected to uphold the highest standards of behavior and to inspire others to do the same.

6. **Equipment**: Paladins are often depicted as wearing heavy armor and carrying a shield and a weapon, such as a sword or a mace. Their equipment may also include holy symbols or other items that represent their faith.

In many stories, paladins are portrayed as selfless heroes who are willing to make great sacrifices for the sake of their cause. However, they can also be portrayed as overzealous or even fanatical, depending on the story and the character's personal beliefs.[INST] 

Candidates:

(32768):  #0[0.892989|18.359356] #1[0.068965|15.798388] #2[0.007043|13.516762] #3[0.005626|13.292268] #4[0.005137|13.201305] #5[0.002399|12.439665] #6[0.002283|12.390135] #7[0.002116|12.314234]
 Will cut #28 at 0.000104

While the result is not optimal, it still doesn't drift away into hallucinations because only the first choice is affected.

Going back to your suggestion of adding the needed word to the start of the output as a guidance: it works well together with K-Shift, actually. I've tested it with completion (which would be technically the same) on llama-cli, and it was even better than I anticipated.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

examples server testing Everything test related

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants