Commit 5224381

Move things around for smooth interactive mode and consistency
1 parent 480a1f0 commit 5224381

File tree

1 file changed: +12 −13 lines changed

src/main/java/com/example/model/Model.java

Lines changed: 12 additions & 13 deletions
@@ -62,10 +62,11 @@ List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> pr
      * @param options
      */
     default void runInteractive(Sampler sampler, Options options) {
-        State state = null;
+        // Even though might be expensive, create state here for smoother interaction later
+        State state = createNewState();
         List<Integer> conversationTokens = new ArrayList<>();
-
         ChatFormat chatFormat = chatFormat();
+        TornadoVMMasterPlan tornadoVMPlan = null;

         if (!getModelType().equals(ModelType.QWEN_3)) {
             conversationTokens.add(chatFormat.getBeginOfText());
@@ -79,7 +80,9 @@ default void runInteractive(Sampler sampler, Options options) {
         Scanner in = new Scanner(System.in);

         // Initialize TornadoVM plan once at the beginning if GPU path is enabled
-        TornadoVMMasterPlan tornadoVMPlan = null;
+        if (USE_TORNADOVM && tornadoVMPlan == null) {
+            tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
+        }

         try {
             while (true) {
@@ -89,15 +92,6 @@ default void runInteractive(Sampler sampler, Options options) {
                 if (List.of("quit", "exit").contains(userText)) {
                     break;
                 }
-                if (state == null) {
-                    // State allocation can take some time for large context sizes,
-                    // allocate the model state only after printing the user '>' prompt.
-                    state = createNewState();
-                }
-
-                if (USE_TORNADOVM && tornadoVMPlan == null) {
-                    tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
-                }

                 conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
                 conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
@@ -177,6 +171,12 @@ default void runInstructOnce(Sampler sampler, Options options) {
         if (options.systemPrompt() != null) {
             promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
         }
+
+        // Initialize TornadoVM plan once at the beginning if GPU path is enabled
+        if (USE_TORNADOVM && tornadoVMPlan == null) {
+            tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
+        }
+
         promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt())));
         promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));

@@ -194,7 +194,6 @@ default void runInstructOnce(Sampler sampler, Options options) {

         if (USE_TORNADOVM) {
             // GPU path using TornadoVM
-            tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
             // Call generateTokensGPU without the token consumer parameter
             responseTokens = generateTokensGPU(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
         } else {
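
For reference, here is a condensed sketch of how runInteractive reads after this commit: the State and the TornadoVMMasterPlan are now set up once, before the prompt loop, instead of lazily inside it, so the first user turn no longer pays the allocation cost. This is only an outline assembled from the hunks above, not the full method or file: imports and the enclosing interface are omitted, the input read (in.nextLine()) and the try/finally shape are assumptions, and everything marked with "..." is code the diff does not show.

    // Sketch only: reconstructed from the hunks above; assumed or elided parts are marked in comments.
    default void runInteractive(Sampler sampler, Options options) {
        // State is created eagerly now, before the prompt loop, instead of on the first user turn.
        State state = createNewState();
        List<Integer> conversationTokens = new ArrayList<>();
        ChatFormat chatFormat = chatFormat();
        TornadoVMMasterPlan tornadoVMPlan = null;

        if (!getModelType().equals(ModelType.QWEN_3)) {
            conversationTokens.add(chatFormat.getBeginOfText());
        }
        // ... system-prompt handling not shown in this diff ...

        Scanner in = new Scanner(System.in);

        // The GPU plan is likewise initialized once, before the loop, if the GPU path is enabled.
        if (USE_TORNADOVM && tornadoVMPlan == null) {
            tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
        }

        try {
            while (true) {
                String userText = in.nextLine(); // assumed read; the prompt/read code is not part of this diff
                if (List.of("quit", "exit").contains(userText)) {
                    break;
                }

                conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
                conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
                // ... token generation and response handling not shown in this diff ...
            }
        } finally {
            // ... cleanup not shown in this diff ...
        }
    }

runInstructOnce follows the same pattern after this change: the plan is created right after the system prompt is encoded, so the later USE_TORNADOVM branch calls generateTokensGPU with an already-initialized tornadoVMPlan instead of building it there.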
