@@ -1,25 +1,16 @@
 package com.example;
 
 import com.example.aot.AOT;
-import com.example.auxiliary.ChatFormat;
 import com.example.core.model.tensor.FloatTensor;
-import com.example.inference.CategoricalSampler;
-import com.example.inference.Sampler;
-import com.example.inference.ToppSampler;
-import com.example.inference.engine.impl.Llama;
-import com.example.inference.engine.impl.Options;
+import com.example.inference.sampler.CategoricalSampler;
+import com.example.inference.sampler.Sampler;
+import com.example.inference.sampler.ToppSampler;
+import com.example.model.Model;
 import com.example.loader.weights.ModelLoader;
-import com.example.loader.weights.State;
 import com.example.tornadovm.FloatArrayUtils;
-import com.example.tornadovm.TornadoVMMasterPlan;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
 
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Scanner;
-import java.util.Set;
-import java.util.function.IntConsumer;
 import java.util.random.RandomGenerator;
 import java.util.random.RandomGeneratorFactory;
 
@@ -115,156 +106,20 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp,
         return sampler;
     }
 
-    static void runInteractive(Llama model, Sampler sampler, Options options) {
-        State state = null;
-        List<Integer> conversationTokens = new ArrayList<>();
-        ChatFormat chatFormat = new ChatFormat(model.tokenizer());
-        conversationTokens.add(chatFormat.beginOfText);
-        if (options.systemPrompt() != null) {
-            conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
-        }
-        int startPosition = 0;
-        Scanner in = new Scanner(System.in);
-
-        // Initialize TornadoVM plan once at the beginning if GPU path is enabled
-        TornadoVMMasterPlan tornadoVMPlan = null;
-
-        try {
-            while (true) {
-                System.out.print("> ");
-                System.out.flush();
-                String userText = in.nextLine();
-                if (List.of("quit", "exit").contains(userText)) {
-                    break;
-                }
-                if (state == null) {
-                    state = model.createNewState();
-                }
-
-                if (USE_TORNADOVM && tornadoVMPlan == null) {
-                    tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
-                }
-
-                conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
-                conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
-                Set<Integer> stopTokens = chatFormat.getStopTokens();
-
-                List<Integer> responseTokens;
-                IntConsumer tokenConsumer = token -> {
-                    if (options.stream()) {
-                        if (!model.tokenizer().isSpecialToken(token)) {
-                            System.out.print(model.tokenizer().decode(List.of(token)));
-                        }
-                    }
-                };
-
-                // Choose between GPU and CPU path based on configuration
-                if (USE_TORNADOVM) {
-                    // GPU path using TornadoVM
-                    responseTokens = Llama.generateTokensGPU(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(),
-                            sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
-                } else {
-                    // CPU path
-                    responseTokens = Llama.generateTokens(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler,
-                            options.echo(), tokenConsumer);
-                }
-
-                // Include stop token in the prompt history, but not in the response displayed to the user.
-                conversationTokens.addAll(responseTokens);
-                startPosition = conversationTokens.size();
-                Integer stopToken = null;
-                if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
-                    stopToken = responseTokens.getLast();
-                    responseTokens.removeLast();
-                }
-                if (!options.stream()) {
-                    String responseText = model.tokenizer().decode(responseTokens);
-                    System.out.println(responseText);
-                }
-                if (stopToken == null) {
-                    System.err.println("\n Ran out of context length...\n Increase the context length by passing --max-tokens XXX to llama-tornado");
-                    break;
-                }
-                System.out.print("\n");
-
-                // Optionally print performance metrics after each response
-                if (SHOW_PERF_INTERACTIVE) {
-                    Llama.LastRunMetrics.printMetrics();
-                }
-            }
-        } finally {
-            // Clean up TornadoVM resources when exiting the chat loop
-            if (USE_TORNADOVM && tornadoVMPlan != null) {
-                try {
-                    tornadoVMPlan.freeTornadoExecutionPlan();
-                } catch (Exception e) {
-                    System.err.println("Error while cleaning up TornadoVM resources: " + e.getMessage());
-                }
-            }
-        }
-    }
-
-    static void runInstructOnce(Llama model, Sampler sampler, Options options) {
-        State state = model.createNewState();
-        ChatFormat chatFormat = new ChatFormat(model.tokenizer());
-        TornadoVMMasterPlan tornadoVMPlan = null;
-
-        List<Integer> promptTokens = new ArrayList<>();
-        promptTokens.add(chatFormat.beginOfText);
-        if (options.systemPrompt() != null) {
-            promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
-        }
-        promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt())));
-        promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
-        List<Integer> responseTokens;
-
-        // Define the token consumer
-        IntConsumer tokenConsumer = token -> {
-            if (options.stream()) {
-                if (!model.tokenizer().isSpecialToken(token)) {
-                    System.out.print(model.tokenizer().decode(List.of(token)));
-                }
-            }
-        };
-
-        Set<Integer> stopTokens = chatFormat.getStopTokens();
-        if (USE_TORNADOVM) {
-            tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
-            // GPU path passes the token consumer only when streaming is enabled
-            responseTokens = Llama.generateTokensGPU(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
-        } else {
-            // CPU path still uses the token consumer
-            responseTokens = Llama.generateTokens(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
-        }
-
-        if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
-            responseTokens.removeLast();
-        }
-        if (!options.stream()) {
-            String responseText = model.tokenizer().decode(responseTokens);
-            System.out.println(responseText);
-        }
-
-        Llama.LastRunMetrics.printMetrics();
-
-        if (tornadoVMPlan != null) {
-            tornadoVMPlan.freeTornadoExecutionPlan();
-        }
-    }
-
     public static void main(String[] args) throws IOException {
         Options options = Options.parseOptions(args);
-        Llama model;
+        Model model;
         if (USE_AOT) {
             model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens());
         } else {
             model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true);
         }
-        Sampler sampler = selectSampler(model.configuration().vocabularySize, options.temperature(), options.topp(), options.seed());
+        assert model != null;
+        Sampler sampler = selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed());
         if (options.interactive()) {
-            runInteractive(model, sampler, options);
+            model.runInteractive(sampler, options);
         } else {
-            runInstructOnce(model, sampler, options);
+            model.runInstructOnce(sampler, options);
         }
     }
 }
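
For context, below is a minimal sketch of what the new com.example.model.Model abstraction might look like, inferred purely from the call sites in this diff (configuration().vocabularySize(), tokenizer(), createNewState(), runInteractive, runInstructOnce). The actual interface in the commit may declare more members; the placeholder types here stand in for the project's real Sampler, Options, State, Configuration, and Tokenizer classes so the sketch compiles on its own.

// Hypothetical sketch, not the repository's source. Empty placeholder
// interfaces are declared only to make this file self-contained; in the
// project these types live in their own packages (e.g. the new
// com.example.inference.sampler.Sampler imported above).
package com.example.model;

interface Sampler { }
interface Options { }
interface State { }
interface Tokenizer { }
interface Configuration {
    int vocabularySize(); // main() now calls this as an accessor method
}

public interface Model {
    Configuration configuration();
    Tokenizer tokenizer();
    State createNewState();

    // The interactive and single-shot loops removed from this file move
    // behind the interface, so CPU and TornadoVM-backed implementations
    // can each own their generation path and resource cleanup.
    void runInteractive(Sampler sampler, Options options);
    void runInstructOnce(Sampler sampler, Options options);
}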