@@ -38,6 +38,18 @@ public interface Model {
 
     State createNewState(int batchsize);
 
+    default boolean shouldAddBeginOfText() {
+        return true;
+    }
+
+    default boolean shouldAddSystemPrompt() {
+        return true;
+    }
+
+    default boolean shouldIncludeReasoning() {
+        return false;
+    }
+
     /**
      * Wrapper for invoking the model-specific forward pass via InferenceCore.
      *
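The three hooks default to the existing Llama-style behavior, so current models are untouched; a model family that needs different prompt handling overrides them instead of being special-cased via ModelType checks. A minimal sketch of such an override (class name hypothetical, declared abstract so the rest of the Model contract can stay omitted):

    // Hypothetical sketch, not part of this diff: a DeepSeek-R1-Distill-Qwen
    // style model that opts out of the default prompt handling and opts in
    // to forced reasoning.
    public abstract class ReasoningModelBase implements Model {
        @Override
        public boolean shouldAddBeginOfText() {
            return false; // e.g. the chat template already supplies begin-of-text
        }

        @Override
        public boolean shouldAddSystemPrompt() {
            return false; // skip encoding a SYSTEM message even if one is given
        }

        @Override
        public boolean shouldIncludeReasoning() {
            return true; // force the "<think>\n" prefix before generation
        }
    }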
@@ -68,11 +80,11 @@ default void runInteractive(Sampler sampler, Options options) {
         ChatFormat chatFormat = chatFormat();
         TornadoVMMasterPlan tornadoVMPlan = null;
 
-        if (!getModelType().equals(ModelType.QWEN_3) && !getModelType().equals(ModelType.PHI_3)) {
+        if (shouldAddBeginOfText()) {
             conversationTokens.add(chatFormat.getBeginOfText());
         }
 
-        if (options.systemPrompt() != null) {
+        if (shouldAddSystemPrompt() && options.systemPrompt() != null) {
             conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
         }
 
@@ -95,6 +107,18 @@ default void runInteractive(Sampler sampler, Options options) {
 
             conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
             conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
+
+            // Include reasoning for Deepseek-R1-Distill-Qwen
+            if (shouldIncludeReasoning()) {
+                List<Integer> thinkStartTokens = tokenizer().encode("<think>\n", tokenizer().getSpecialTokens().keySet());
+                conversationTokens.addAll(thinkStartTokens);
+
+                // If streaming, immediately output the think start
+                if (options.stream()) {
+                    System.out.print("<think>\n");
+                }
+            }
+
             Set<Integer> stopTokens = chatFormat.getStopTokens();
 
             List<Integer> responseTokens;
@@ -127,6 +151,10 @@ default void runInteractive(Sampler sampler, Options options) {
             }
             if (!options.stream()) {
                 String responseText = tokenizer().decode(responseTokens);
+                // Add the forced <think>\n prefix for non-streaming output
+                if (shouldIncludeReasoning()) {
+                    responseText = "<think>\n" + responseText;
+                }
                 System.out.println(responseText);
             }
             if (stopToken == null) {
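Note the asymmetry between the two output modes: the forced <think>\n tokens are added to conversationTokens, i.e. the prompt side, so they never appear in responseTokens. When streaming, the prefix is therefore printed the moment it is injected; when not streaming, tokenizer().decode(responseTokens) yields text that starts after the prefix, so it has to be prepended manually before printing.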
@@ -164,11 +192,11 @@ default void runInstructOnce(Sampler sampler, Options options) {
 
         List<Integer> promptTokens = new ArrayList<>();
 
-        if (!getModelType().equals(ModelType.QWEN_3) && !getModelType().equals(ModelType.QWEN_2) && !getModelType().equals(ModelType.PHI_3)) {
+        if (shouldAddBeginOfText()) {
             promptTokens.add(chatFormat.getBeginOfText());
         }
 
-        if (options.systemPrompt() != null) {
+        if (shouldAddSystemPrompt() && options.systemPrompt() != null) {
             promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
         }
 
@@ -180,6 +208,17 @@ default void runInstructOnce(Sampler sampler, Options options) {
         promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt())));
         promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
 
+        // Include reasoning for Deepseek-R1-Distill-Qwen
+        if (shouldIncludeReasoning()) {
+            List<Integer> thinkStartTokens = tokenizer().encode("<think>\n", tokenizer().getSpecialTokens().keySet());
+            promptTokens.addAll(thinkStartTokens);
+
+            // If streaming, immediately output the think start
+            if (options.stream()) {
+                System.out.print("<think>\n");
+            }
+        }
+
         List<Integer> responseTokens;
 
         IntConsumer tokenConsumer = token -> {
@@ -206,6 +245,10 @@ default void runInstructOnce(Sampler sampler, Options options) {
         }
         if (!options.stream()) {
             String responseText = tokenizer().decode(responseTokens);
+            // Add the forced <think>\n prefix for non-streaming output
+            if (shouldIncludeReasoning()) {
+                responseText = "<think>\n" + responseText;
+            }
             System.out.println(responseText);
         }
 
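The reasoning-prefix block is now duplicated across runInteractive and runInstructOnce; a follow-up could hoist it into one more default hook on Model. A hypothetical sketch (method name assumed, not part of this diff), using only calls that already appear above:

    // Hypothetical helper, not in this diff: centralizes the forced
    // "<think>\n" prefix so both run paths stay in sync.
    default void maybeStartReasoning(List<Integer> tokens, Options options) {
        if (!shouldIncludeReasoning()) {
            return;
        }
        tokens.addAll(tokenizer().encode("<think>\n", tokenizer().getSpecialTokens().keySet()));
        if (options.stream()) {
            // Echo the forced prefix immediately, mirroring the streaming branches above
            System.out.print("<think>\n");
        }
    }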