@@ -62,10 +62,11 @@ List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> pr
62
62
* @param options
63
63
*/
64
64
default void runInteractive (Sampler sampler , Options options ) {
65
- State state = null ;
65
+ // Even though might be expensive, create state here for smoother interaction later
66
+ State state = createNewState ();
66
67
List <Integer > conversationTokens = new ArrayList <>();
67
-
68
68
ChatFormat chatFormat = chatFormat ();
69
+ TornadoVMMasterPlan tornadoVMPlan = null ;
69
70
70
71
if (!getModelType ().equals (ModelType .QWEN_3 )) {
71
72
conversationTokens .add (chatFormat .getBeginOfText ());
@@ -79,7 +80,9 @@ default void runInteractive(Sampler sampler, Options options) {
79
80
Scanner in = new Scanner (System .in );
80
81
81
82
// Initialize TornadoVM plan once at the beginning if GPU path is enabled
82
- TornadoVMMasterPlan tornadoVMPlan = null ;
83
+ if (USE_TORNADOVM && tornadoVMPlan == null ) {
84
+ tornadoVMPlan = TornadoVMMasterPlan .initializeTornadoVMPlan (state , this );
85
+ }
83
86
84
87
try {
85
88
while (true ) {
@@ -89,15 +92,6 @@ default void runInteractive(Sampler sampler, Options options) {
89
92
if (List .of ("quit" , "exit" ).contains (userText )) {
90
93
break ;
91
94
}
92
- if (state == null ) {
93
- // State allocation can take some time for large context sizes,
94
- // allocate the model state only after printing the user '>' prompt.
95
- state = createNewState ();
96
- }
97
-
98
- if (USE_TORNADOVM && tornadoVMPlan == null ) {
99
- tornadoVMPlan = TornadoVMMasterPlan .initializeTornadoVMPlan (state , this );
100
- }
101
95
102
96
conversationTokens .addAll (chatFormat .encodeMessage (new ChatFormat .Message (ChatFormat .Role .USER , userText )));
103
97
conversationTokens .addAll (chatFormat .encodeHeader (new ChatFormat .Message (ChatFormat .Role .ASSISTANT , "" )));
@@ -177,6 +171,12 @@ default void runInstructOnce(Sampler sampler, Options options) {
177
171
if (options .systemPrompt () != null ) {
178
172
promptTokens .addAll (chatFormat .encodeMessage (new ChatFormat .Message (ChatFormat .Role .SYSTEM , options .systemPrompt ())));
179
173
}
174
+
175
+ // Initialize TornadoVM plan once at the beginning if GPU path is enabled
176
+ if (USE_TORNADOVM && tornadoVMPlan == null ) {
177
+ tornadoVMPlan = TornadoVMMasterPlan .initializeTornadoVMPlan (state , this );
178
+ }
179
+
180
180
promptTokens .addAll (chatFormat .encodeMessage (new ChatFormat .Message (ChatFormat .Role .USER , options .prompt ())));
181
181
promptTokens .addAll (chatFormat .encodeHeader (new ChatFormat .Message (ChatFormat .Role .ASSISTANT , "" )));
182
182
@@ -194,7 +194,6 @@ default void runInstructOnce(Sampler sampler, Options options) {
194
194
195
195
if (USE_TORNADOVM ) {
196
196
// GPU path using TornadoVM
197
- tornadoVMPlan = TornadoVMMasterPlan .initializeTornadoVMPlan (state , this );
198
197
// Call generateTokensGPU without the token consumer parameter
199
198
responseTokens = generateTokensGPU (state , 0 , promptTokens , stopTokens , options .maxTokens (), sampler , options .echo (), options .stream () ? tokenConsumer : null , tornadoVMPlan );
200
199
} else {
0 commit comments