 import java.util.List;
 import java.util.Scanner;
 import java.util.Set;
+import java.util.function.Consumer;
 import java.util.function.IntConsumer;
 
 import static org.beehive.gpullama3.LlamaApp.SHOW_PERF_INTERACTIVE;
@@ -218,4 +219,72 @@ default String runInstructOnce(Sampler sampler, Options options) {
 
         return responseText;
     }
+
+    /**
+     * Model-agnostic default implementation of instruct mode for LangChain4j.
+     * When streaming is enabled, each decoded token piece is forwarded to the callback.
+     *
+     * @param sampler       sampler used to select the next token
+     * @param options       generation options (prompt, system prompt, max tokens, streaming)
+     * @param tokenCallback consumer invoked with each decoded piece while streaming; may be null
+     * @return the decoded response text
+     */
+    default String runInstructOnceLangChain4J(Sampler sampler, Options options, Consumer<String> tokenCallback) {
+        State state = createNewState();
+        ChatFormat chatFormat = chatFormat();
+        TornadoVMMasterPlan tornadoVMPlan = null;
+
+        List<Integer> promptTokens = new ArrayList<>();
+
+        // Skip the begin-of-text token for the Qwen3 and Phi3 chat formats
+        if (!getModelType().equals(ModelType.QWEN_3) && !getModelType().equals(ModelType.PHI_3)) {
+            promptTokens.add(chatFormat.getBeginOfText());
+        }
+
+        if (options.systemPrompt() != null) {
+            promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
+        }
+
+        // Initialize the TornadoVM plan once at the beginning if the GPU path is enabled
+        if (USE_TORNADOVM && tornadoVMPlan == null) {
+            tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
+        }
+
+        promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt())));
+        promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
+
+        List<Integer> responseTokens;
+
+        IntConsumer tokenConsumer = token -> {
+            if (tokenizer().shouldDisplayToken(token)) {
+                String piece = tokenizer().decode(List.of(token));
+                if (options.stream() && tokenCallback != null) {
+                    tokenCallback.accept(piece); // forward the decoded piece to the LangChain4j handler
+                }
+            }
+        };
+
+        Set<Integer> stopTokens = chatFormat.getStopTokens();
+
+        if (USE_TORNADOVM) {
+            // GPU path using TornadoVM; pass the token consumer only when streaming is enabled
+            responseTokens = generateTokensGPU(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
+        } else {
+            // CPU path
+            responseTokens = generateTokens(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
+        }
+
+        // Strip a trailing stop token so it does not leak into the decoded text
+        if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
+            responseTokens.removeLast();
+        }
+
+        // Decode once; streamed pieces have already been delivered via the callback
+        String responseText = tokenizer().decode(responseTokens);
+
+        if (tornadoVMPlan != null) {
+            tornadoVMPlan.freeTornadoExecutionPlan();
+        }
+
+        return responseText;
+    }
 }
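
For reference, a minimal sketch of how the new entry point might be driven from application code; the `model`, `sampler`, and `options` values here are assumed to be set up elsewhere and are not part of this diff:

```java
// Hypothetical caller (not part of this diff): collect the full response while
// streaming each decoded piece to a callback. Assumes `model` implements this
// interface and `options` was built with streaming enabled.
StringBuilder streamed = new StringBuilder();
String response = model.runInstructOnceLangChain4J(sampler, options, piece -> {
    streamed.append(piece);   // accumulate partial output
    System.out.print(piece);  // or forward to a LangChain4j streaming handler
});
```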