Skip to content

Commit 1fba5bf

Browse files
Fix reasoning management in Deepseek-R1-Distill-Qwen and Qwen models
1 parent d1239eb commit 1fba5bf

File tree

5 files changed

+91
-6
lines changed

5 files changed

+91
-6
lines changed

src/main/java/com/example/model/qwen2/Qwen2.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,32 @@ public State createNewState(int batchsize) {
5454
return state;
5555
}
5656

57+
/**
58+
* No <|begin_of_text|> needed for Qwen models.
59+
*/
60+
@Override
61+
public boolean shouldAddBeginOfText() {
62+
return false;
63+
}
64+
65+
/**
66+
* No system prompt for Deepseek-R1-Distill-Qwen.
67+
* Based on <a href="https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B">Usage Recommendations</a>
68+
*/
69+
@Override
70+
public boolean shouldAddSystemPrompt() {
71+
return !getModelType().isDeepSeekR1();
72+
}
73+
74+
/**
75+
* Force inclusion of <think></think> for Deepseek-R1-Distill-Qwen.
76+
* Based on <a href="https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B">Usage Recommendations</a>
77+
*/
78+
@Override
79+
public boolean shouldIncludeReasoning() {
80+
return getModelType().isDeepSeekR1();
81+
}
82+
5783
@Override
5884
public void forward(State state, int token, int position) {
5985
if (plan == null) {

src/main/java/org/beehive/gpullama3/model/Model.java

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,18 @@ public interface Model {
3838

3939
State createNewState(int batchsize);
4040

41+
default boolean shouldAddBeginOfText() {
42+
return true;
43+
}
44+
45+
default boolean shouldAddSystemPrompt() {
46+
return true;
47+
}
48+
49+
default boolean shouldIncludeReasoning() {
50+
return false;
51+
}
52+
4153
/**
4254
* Wrapper for invoking the model-specific forward pass via InferenceCore.
4355
*
@@ -68,11 +80,11 @@ default void runInteractive(Sampler sampler, Options options) {
6880
ChatFormat chatFormat = chatFormat();
6981
TornadoVMMasterPlan tornadoVMPlan = null;
7082

71-
if (!getModelType().equals(ModelType.QWEN_3) && !getModelType().equals(ModelType.PHI_3)) {
83+
if (shouldAddBeginOfText()) {
7284
conversationTokens.add(chatFormat.getBeginOfText());
7385
}
7486

75-
if (options.systemPrompt() != null) {
87+
if (shouldAddSystemPrompt() && options.systemPrompt() != null) {
7688
conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
7789
}
7890

@@ -95,6 +107,18 @@ default void runInteractive(Sampler sampler, Options options) {
95107

96108
conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
97109
conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
110+
111+
// Include reasoning for Deepseek-R1-Distill-Qwen
112+
if (shouldIncludeReasoning()) {
113+
List<Integer> thinkStartTokens = tokenizer().encode("<think>\n", tokenizer().getSpecialTokens().keySet());
114+
conversationTokens.addAll(thinkStartTokens);
115+
116+
// If streaming, immediately output the think start
117+
if (options.stream()) {
118+
System.out.print("<think>\n");
119+
}
120+
}
121+
98122
Set<Integer> stopTokens = chatFormat.getStopTokens();
99123

100124
List<Integer> responseTokens;
@@ -127,6 +151,10 @@ default void runInteractive(Sampler sampler, Options options) {
127151
}
128152
if (!options.stream()) {
129153
String responseText = tokenizer().decode(responseTokens);
154+
// Add the forced <think>\n prefix for non-streaming output
155+
if (shouldIncludeReasoning()) {
156+
responseText = "<think>\n" + responseText;
157+
}
130158
System.out.println(responseText);
131159
}
132160
if (stopToken == null) {
@@ -164,11 +192,11 @@ default void runInstructOnce(Sampler sampler, Options options) {
164192

165193
List<Integer> promptTokens = new ArrayList<>();
166194

167-
if (!getModelType().equals(ModelType.QWEN_3) && !getModelType().equals(ModelType.QWEN_2) && !getModelType().equals(ModelType.PHI_3)) {
195+
if (shouldAddBeginOfText()) {
168196
promptTokens.add(chatFormat.getBeginOfText());
169197
}
170198

171-
if (options.systemPrompt() != null) {
199+
if (shouldAddSystemPrompt() && options.systemPrompt() != null) {
172200
promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
173201
}
174202

@@ -180,6 +208,17 @@ default void runInstructOnce(Sampler sampler, Options options) {
180208
promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt())));
181209
promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
182210

211+
// Include reasoning for Deepseek-R1-Distill-Qwen
212+
if (shouldIncludeReasoning()) {
213+
List<Integer> thinkStartTokens = tokenizer().encode("<think>\n", tokenizer().getSpecialTokens().keySet());
214+
promptTokens.addAll(thinkStartTokens);
215+
216+
// If streaming, immediately output the think start
217+
if (options.stream()) {
218+
System.out.print("<think>\n");
219+
}
220+
}
221+
183222
List<Integer> responseTokens;
184223

185224
IntConsumer tokenConsumer = token -> {
@@ -206,6 +245,10 @@ default void runInstructOnce(Sampler sampler, Options options) {
206245
}
207246
if (!options.stream()) {
208247
String responseText = tokenizer().decode(responseTokens);
248+
// Add the forced <think>\n prefix for non-streaming output
249+
if (shouldIncludeReasoning()) {
250+
responseText = "<think>\n" + responseText;
251+
}
209252
System.out.println(responseText);
210253
}
211254

src/main/java/org/beehive/gpullama3/model/phi3/Phi3.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ public State createNewState(int batchsize) {
5353
return state;
5454
}
5555

56+
/**
57+
* No begin of text needed for Phi3 models.
58+
*/
59+
@Override
60+
public boolean shouldAddBeginOfText() {
61+
return false;
62+
}
63+
5664
@Override
5765
public void forward(State state, int token, int position) {
5866
if (plan == null) {

src/main/java/org/beehive/gpullama3/model/qwen3/Qwen3.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ public State createNewState(int batchsize) {
5353
return state;
5454
}
5555

56+
/**
57+
* No begin of text needed for Qwen models.
58+
*/
59+
@Override
60+
public boolean shouldAddBeginOfText() {
61+
return false;
62+
}
63+
5664
@Override
5765
public void forward(State state, int token, int position) {
5866
if (plan == null) {

src/main/java/org/beehive/gpullama3/tokenizer/impl/Qwen3Tokenizer.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ public boolean isSpecialToken(int tokenIndex) {
5353
@Override
5454
public boolean shouldDisplayToken(int token) {
5555
int tokenType = getTokenType(token);
56-
57-
return tokenType == 1 || tokenType == 6;
56+
// tokenType 4 allows the display of reasoning ( <think> ... </think> )
57+
return tokenType == 1 || tokenType == 4 || tokenType == 6;
5858
}
5959

6060
public int getTokenType(int tokenIndex) {

0 commit comments

Comments
 (0)