13 | 13 | import java.util.List;
14 | 14 | import java.util.Scanner;
15 | 15 | import java.util.Set;
| 16 | +import java.util.function.Consumer;
16 | 17 | import java.util.function.IntConsumer;
17 | 18 |
18 | 19 | import static org.beehive.gpullama3.LlamaApp.SHOW_PERF_INTERACTIVE;

@@ -257,4 +258,78 @@ default String runInstructOnce(Sampler sampler, Options options) {

257 | 258 |
258 | 259 |         return responseText;
259 | 260 |     }
| 261 | +
| 262 | +    default String runInstructOnceLangChain4J(Sampler sampler, Options options, Consumer<String> tokenCallback) {
| 263 | +        State state = createNewState();
| 264 | +        ChatFormat chatFormat = chatFormat();
| 265 | +        TornadoVMMasterPlan tornadoVMPlan = null;
| 266 | +
| 267 | +        List<Integer> promptTokens = new ArrayList<>();
| 268 | +
| 269 | +        if (shouldAddBeginOfText()) {
| 270 | +            promptTokens.add(chatFormat.getBeginOfText());
| 271 | +        }
| 272 | +
| 273 | +        if (shouldAddSystemPrompt() && options.systemPrompt() != null) {
| 274 | +            promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
| 275 | +        }
| 276 | +
| 277 | +        // Initialize the TornadoVM plan up front when the GPU path is enabled
| 278 | +        if (Options.getDefaultOptions().useTornadovm()) {
| 279 | +            tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
| 280 | +        }
| 281 | +
| 282 | +        promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt())));
| 283 | +        promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
| 284 | +
| 285 | +        if (shouldIncludeReasoning()) {
| 286 | +            List<Integer> thinkStartTokens = tokenizer().encode("<think>\n", tokenizer().getSpecialTokens().keySet());
| 287 | +            promptTokens.addAll(thinkStartTokens);
| 288 | +
| 289 | +            // If streaming, emit the think-start immediately so the callback receives the full response
| 290 | +            if (options.stream() && tokenCallback != null) {
| 291 | +                tokenCallback.accept("<think>\n");
| 292 | +            }
| 293 | +        }
| 294 | +
| 295 | +        List<Integer> responseTokens;
| 296 | +
| 297 | +        IntConsumer tokenConsumer = token -> {
| 298 | +            if (tokenizer().shouldDisplayToken(token)) {
| 299 | +                String piece = tokenizer().decode(List.of(token));
| 300 | +                if (options.stream() && tokenCallback != null) {
| 301 | +                    tokenCallback.accept(piece); // send to LangChain4j handler
| 302 | +                }
| 303 | +            }
| 304 | +        };
| 305 | +
| 306 | +        Set<Integer> stopTokens = chatFormat.getStopTokens();
| 307 | +
| 308 | +        if (Options.getDefaultOptions().useTornadovm()) {
| 309 | +            // GPU path using TornadoVM; pass the token consumer only when streaming
| 310 | +            responseTokens = generateTokensGPU(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
| 311 | +        } else {
| 312 | +            // CPU path
| 313 | +            responseTokens = generateTokens(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
| 314 | +        }
| 315 | +
| 316 | +        if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
| 317 | +            responseTokens.removeLast();
| 318 | +        }
| 319 | +
| 320 | +        String responseText = tokenizer().decode(responseTokens);
| 321 | +
| 322 | +        if (!options.stream()) {
| 323 | +            // responseText was already decoded above; only the reasoning prefix needs re-attaching
| 324 | +            if (shouldIncludeReasoning()) {
| 325 | +                responseText = "<think>\n" + responseText;
| 326 | +            }
| 327 | +        }
| 328 | +
| 329 | +        if (tornadoVMPlan != null) {
| 330 | +            tornadoVMPlan.freeTornadoExecutionPlan();
| 331 | +        }
| 332 | +
| 333 | +        return responseText;
| 334 | +    }
260 | 335 | }
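
For context, here is a minimal sketch of how a caller might wire this new method into a LangChain4j-style streaming flow. It assumes the enclosing interface is named `Model` and that `sampler` and `options` come from the existing gpullama3 bootstrap code; the `StringBuilder`-based handler is a stand-in for a real LangChain4j partial-response handler, so treat the wiring as illustrative rather than the project's actual adapter.

```java
import java.util.function.Consumer;

// Illustrative sketch only: Model, Sampler, and Options are the surrounding
// gpullama3 types; how they are constructed is out of scope here.
static String generateStreaming(Model model, Sampler sampler, Options options) {
    StringBuilder streamed = new StringBuilder();
    Consumer<String> onPiece = piece -> {
        streamed.append(piece);   // accumulate partial tokens as they arrive
        System.out.print(piece);  // or forward to a LangChain4j partial-response handler
    };
    String full = model.runInstructOnceLangChain4J(sampler, options, onPiece);
    // With options.stream() enabled, `streamed` and `full` carry the same text.
    return full;
}
```

The `Consumer<String>` parameter keeps the method decoupled from any particular LangChain4j handler type: the caller decides whether each decoded piece goes to stdout, a response builder, or an `onNext`-style callback.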