@@ -48,51 +48,87 @@ public State createNewState(int batchsize) {
 
     @Override
     public void runInteractive(Sampler sampler, Options options) {
-        MistralTokenizer mistralTokenizer = getAsMistralTokenizer();
         State state = null;
-        MistralChatFormat chatFormat = new MistralChatFormat(getAsMistralTokenizer());
         List<Integer> conversationTokens = new ArrayList<>();
+
+        MistralChatFormat chatFormat = new MistralChatFormat(getAsMistralTokenizer());
         conversationTokens.add(chatFormat.getBeginOfText());
+
         int startPosition = 0;
         Scanner in = new Scanner(System.in);
-        while (true) {
-            System.out.print("> ");
-            System.out.flush();
-            if (state == null) {
-                // State allocation can take some time for large context sizes,
-                // allocate the model state only after printing the user '>' prompt.
-                state = createNewState();
-            }
-            String userText = in.nextLine();
-            if (List.of("quit", "exit").contains(userText)) {
-                break;
-            }
-            conversationTokens.addAll(chatFormat.encodeMessage(userText, true, true));
-            Set<Integer> stopTokens = chatFormat.getStopTokens();
-            List<Integer> responseTokens = generateTokens(state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler,
-                    options.echo(), token -> {
-                        if (options.stream()) {
-                            int tokenType = mistralTokenizer.getTokenType(token);
-                            if (tokenType == 1 || tokenType == 6) {
-                                System.out.print(mistralTokenizer.decode(List.of(token)));
-                            }
+
+        // Initialize the TornadoVM plan once at the beginning if the GPU path is enabled
+        TornadoVMMasterPlan tornadoVMPlan = null;
+
+        try {
+            while (true) {
+                System.out.print("> ");
+                System.out.flush();
+                String userText = in.nextLine();
+                if (List.of("quit", "exit").contains(userText)) {
+                    break;
+                }
+                if (state == null) {
+                    // State allocation can take some time for large context sizes,
+                    // allocate the model state only after printing the user '>' prompt.
+                    state = createNewState();
+                }
+
+                if (USE_TORNADOVM && tornadoVMPlan == null) {
+                    tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
+                }
+
+                conversationTokens.addAll(chatFormat.encodeMessage(userText, true, true));
+                Set<Integer> stopTokens = chatFormat.getStopTokens();
+
+                List<Integer> responseTokens;
+                IntConsumer tokenConsumer = token -> {
+                    if (options.stream()) {
+                        if (!tokenizer.isSpecialToken(token)) {
+                            System.out.print(tokenizer.decode(List.of(token)));
                         }
-                    });
-            // Include stop token in the prompt history, but not in the response displayed to the user.
-            conversationTokens.addAll(responseTokens);
-            startPosition = conversationTokens.size();
-            Integer stopToken = null;
-            if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
-                stopToken = responseTokens.getLast();
-                responseTokens.removeLast();
-            }
-            if (!options.stream()) {
-                String responseText = mistralTokenizer.decode(responseTokens);
-                System.out.println(responseText);
+                    }
+                };
+
+                // Choose between the GPU and CPU paths based on configuration
+                if (USE_TORNADOVM) {
+                    // GPU path using TornadoVM
+                    responseTokens = InferenceEngine.generateTokensGPU(this, state, startPosition,
+                            conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens,
+                            options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
+                } else {
+                    // CPU path
+                    responseTokens = InferenceEngine.generateTokens(this, state, startPosition,
+                            conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens,
+                            options.maxTokens(), sampler, options.echo(), tokenConsumer);
+                }
+
+                // Include the stop token in the prompt history, but not in the response displayed to the user.
+                conversationTokens.addAll(responseTokens);
+                startPosition = conversationTokens.size();
+                Integer stopToken = null;
+                if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
+                    stopToken = responseTokens.getLast();
+                    responseTokens.removeLast();
+                }
+                if (!options.stream()) {
+                    String responseText = tokenizer.decode(responseTokens);
+                    System.out.println(responseText);
+                }
+                if (stopToken == null) {
+                    System.err.println("Ran out of context length...\nIncrease the context length by passing --max-tokens XXX to llama-tornado");
+                    break;
+                }
+                System.out.print("\n");
             }
-            if (stopToken == null) {
-                System.err.println("Ran out of context length...");
-                break;
+        } finally {
+            // Clean up TornadoVM resources when exiting the chat loop
+            if (USE_TORNADOVM && tornadoVMPlan != null) {
+                try {
+                    tornadoVMPlan.freeTornadoExecutionPlan();
+                } catch (Exception e) {
+                    System.err.println("Error while cleaning up TornadoVM resources: " + e.getMessage());
+                }
            }
        }
    }
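The restructured loop above funnels all streaming output through a single `IntConsumer` and then dispatches to `InferenceEngine.generateTokensGPU` or `InferenceEngine.generateTokens` depending on `USE_TORNADOVM`. A minimal sketch of the consumer pattern, with a hypothetical `Tokenizer` interface standing in for the project's actual tokenizer API:

```java
import java.util.List;
import java.util.function.IntConsumer;

public class StreamingConsumerSketch {
    // Hypothetical stand-in for the tokenizer used in the diff; the real
    // project type exposes equivalent methods, possibly under other names.
    interface Tokenizer {
        boolean isSpecialToken(int token);
        String decode(List<Integer> tokens);
    }

    // Builds the per-token callback used on both the CPU and GPU paths:
    // print only printable tokens, skipping chat-format control tokens.
    static IntConsumer streamingConsumer(Tokenizer tokenizer, boolean stream) {
        return token -> {
            if (stream && !tokenizer.isSpecialToken(token)) {
                System.out.print(tokenizer.decode(List.of(token)));
            }
        };
    }
}
```

Note the asymmetry in the dispatch: the GPU call receives `options.stream() ? tokenConsumer : null`, so no per-token host callback fires when streaming is off, while the CPU call always receives the consumer, which then no-ops thanks to the `options.stream()` guard inside it.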
@@ -101,8 +137,11 @@ public void runInteractive(Sampler sampler, Options options) {
     public void runInstructOnce(Sampler sampler, Options options) {
         State state = createNewState();
         MistralChatFormat chatFormat = new MistralChatFormat(getAsMistralTokenizer());
+        TornadoVMMasterPlan tornadoVMPlan = null;
+
         List<Integer> promptTokens = new ArrayList<>();
         promptTokens.add(chatFormat.getBeginOfText());
+
         if (options.suffix() != null) {
             promptTokens.addAll(chatFormat.encodeFillInTheMiddle(options.prompt(), options.suffix()));
         } else {
@@ -120,7 +159,6 @@ public void runInstructOnce(Sampler sampler, Options options) {
             }
         };
 
-        TornadoVMMasterPlan tornadoVMPlan = null;
         if (USE_TORNADOVM) {
             tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
             // Call generateTokensGPU without the token consumer parameter
@@ -140,5 +178,9 @@ public void runInstructOnce(Sampler sampler, Options options) {
         }
 
         LastRunMetrics.printMetrics();
+
+        if (tornadoVMPlan != null) {
+            tornadoVMPlan.freeTornadoExecutionPlan();
+        }
     }
 }
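Both methods now free the TornadoVM execution plan, but differently: `runInteractive` guards the call in a `finally` block with its own try/catch, while `runInstructOnce` calls it unconditionally after printing metrics. A sketch of one way the two could be unified, assuming `freeTornadoExecutionPlan()` is the only teardown needed; `TornadoVMMasterPlan` is the class from this patch, and the wrapper itself is hypothetical:

```java
// Hypothetical helper, not part of the patch: adapts the plan to
// try-with-resources so both entry points release GPU resources identically.
final class TornadoPlanGuard implements AutoCloseable {
    private final TornadoVMMasterPlan plan; // may be null on the CPU-only path

    TornadoPlanGuard(TornadoVMMasterPlan plan) {
        this.plan = plan;
    }

    @Override
    public void close() {
        if (plan != null) {
            try {
                plan.freeTornadoExecutionPlan();
            } catch (Exception e) {
                System.err.println("Error while cleaning up TornadoVM resources: " + e.getMessage());
            }
        }
    }
}
```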