Commit c4a3967
Add runInstructOnceLangChain4J method for LangChain4J integration; initialize the TornadoVM plan only once per invocation and stream generated tokens through a callback.
1 parent c96e4c2 commit c4a3967

File tree

  • src/main/java/org/beehive/gpullama3/model

1 file changed: +69 −0 lines changed

src/main/java/org/beehive/gpullama3/model/Model.java

Lines changed: 69 additions & 0 deletions
@@ -13,6 +13,7 @@
 import java.util.List;
 import java.util.Scanner;
 import java.util.Set;
+import java.util.function.Consumer;
 import java.util.function.IntConsumer;
 
 import static org.beehive.gpullama3.LlamaApp.SHOW_PERF_INTERACTIVE;
@@ -218,4 +219,72 @@ default String runInstructOnce(Sampler sampler, Options options) {
 
         return responseText;
     }
+
+    /**
+     * Model-agnostic default implementation of instruct mode for LangChain4J integration.
+     *
+     * @param sampler       the sampler used to select each next token
+     * @param options       generation options (prompts, max tokens, streaming, echo)
+     * @param tokenCallback consumer invoked with each decoded token piece when streaming is enabled
+     * @return the decoded response text
+     */
+    default String runInstructOnceLangChain4J(Sampler sampler, Options options, Consumer<String> tokenCallback) {
+        State state = createNewState();
+        ChatFormat chatFormat = chatFormat();
+        TornadoVMMasterPlan tornadoVMPlan = null;
+
+        List<Integer> promptTokens = new ArrayList<>();
+
+        if (!getModelType().equals(ModelType.QWEN_3) && !getModelType().equals(ModelType.PHI_3)) {
+            promptTokens.add(chatFormat.getBeginOfText());
+        }
+
+        if (options.systemPrompt() != null) {
+            promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
+        }
+
+        // Initialize the TornadoVM plan once at the beginning if the GPU path is enabled
+        if (USE_TORNADOVM && tornadoVMPlan == null) {
+            tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
+        }
+
+        promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt())));
+        promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
+
+        List<Integer> responseTokens;
+
+        // Decode each displayable token and forward the piece while streaming
+        IntConsumer tokenConsumer = token -> {
+            if (tokenizer().shouldDisplayToken(token)) {
+                String piece = tokenizer().decode(List.of(token));
+                if (options.stream() && tokenCallback != null) {
+                    tokenCallback.accept(piece); // ✅ send to LangChain4j handler
+                }
+            }
+        };
+
+        Set<Integer> stopTokens = chatFormat.getStopTokens();
+
+        if (USE_TORNADOVM) {
+            // GPU path using TornadoVM
+            // Pass the token consumer only when streaming is enabled
+            responseTokens = generateTokensGPU(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
+        } else {
+            // CPU path
+            responseTokens = generateTokens(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
+        }
+
+        // Strip a trailing stop token before decoding the final response
+        if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
+            responseTokens.removeLast();
+        }
+
+        String responseText = tokenizer().decode(responseTokens);
+
+        if (tornadoVMPlan != null) {
+            tornadoVMPlan.freeTornadoExecutionPlan();
+        }
+
+        return responseText;
+    }
 }
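
For context, a minimal caller sketch (not part of this commit): assuming a loaded Model plus a configured Sampler and Options obtained through the repo's existing setup paths, the new method can be driven with a plain Consumer<String> that receives each token piece; a real LangChain4J adapter would forward the piece to its streaming handler instead of printing it. The class and method names below are hypothetical.

// Hypothetical usage sketch; `model`, `sampler`, and `options` are assumed
// to come from the repo's existing loading and configuration code.
import java.util.function.Consumer;

class StreamingExample {

    static String generate(Model model, Sampler sampler, Options options) {
        // Called once per decoded token piece while options.stream() is true;
        // a LangChain4J adapter would hand the piece to its streaming handler.
        Consumer<String> onToken = System.out::print;

        // Blocks until generation finishes and returns the full response text.
        return model.runInstructOnceLangChain4J(sampler, options, onToken);
    }
}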
