@@ -1,5 +1,8 @@
 package com.example.model;
 
+import com.example.aux.LastRunMetrics;
+import com.example.aux.format.ChatFormat;
+import com.example.inference.InferenceEngine;
 import com.example.inference.sampler.Sampler;
 import com.example.Options;
 import com.example.loader.weights.ModelLoader.ModelType;
@@ -8,10 +11,15 @@
 import com.example.tokenizer.impl.Tokenizer;
 import com.example.tornadovm.TornadoVMMasterPlan;
 
+import java.util.ArrayList;
 import java.util.List;
+import java.util.Scanner;
 import java.util.Set;
 import java.util.function.IntConsumer;
 
+import static com.example.LlamaApp.SHOW_PERF_INTERACTIVE;
+import static com.example.LlamaApp.USE_TORNADOVM;
+
 public interface Model {
     Configuration configuration();
     Tokenizer tokenizer();
@@ -22,6 +30,105 @@ public interface Model {
     State createNewState();
     State createNewState(int batchsize);
 
-    void runInteractive(Sampler sampler, Options options);
+    /**
+     * Model-agnostic default implementation for interactive mode.
+     * @param sampler strategy used to select each generated token
+     * @param options runtime options (system prompt, streaming, max tokens, echo)
+     */
+    default void runInteractive(Sampler sampler, Options options) {
+        State state = null;
+        List<Integer> conversationTokens = new ArrayList<>();
+
+        ChatFormat chatFormat = ChatFormat.create(tokenizer());
+        conversationTokens.add(chatFormat.getBeginOfText());
+
+        if (options.systemPrompt() != null) {
+            conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
+        }
+
+        int startPosition = 0;
+        Scanner in = new Scanner(System.in);
+
+        // Initialize TornadoVM plan once at the beginning if GPU path is enabled
+        TornadoVMMasterPlan tornadoVMPlan = null;
+
+        try {
+            while (true) {
+                System.out.print("> ");
+                System.out.flush();
+                String userText = in.nextLine();
+                if (List.of("quit", "exit").contains(userText)) {
+                    break;
+                }
+                if (state == null) {
+                    // State allocation can take some time for large context sizes,
+                    // allocate the model state only after printing the user '>' prompt.
+                    state = createNewState();
+                }
+
+                if (USE_TORNADOVM && tornadoVMPlan == null) {
+                    tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
+                }
+
+                conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
+                conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
+                Set<Integer> stopTokens = chatFormat.getStopTokens();
+
+                List<Integer> responseTokens;
+                IntConsumer tokenConsumer = token -> {
+                    if (options.stream()) {
+                        if (tokenizer().shouldDisplayToken(token)) {
+                            System.out.print(tokenizer().decode(List.of(token)));
+                        }
+                    }
+                };
+
+                // Choose between GPU and CPU path based on configuration
+                if (USE_TORNADOVM) {
+                    // GPU path using TornadoVM
+                    responseTokens = InferenceEngine.generateTokensGPU(this, state, startPosition,
+                            conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens,
+                            options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
+                } else {
+                    // CPU path
+                    responseTokens = InferenceEngine.generateTokens(this, state, startPosition,
+                            conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens,
+                            options.maxTokens(), sampler, options.echo(), tokenConsumer);
+                }
+
+                // Include stop token in the prompt history, but not in the response displayed to the user.
+                conversationTokens.addAll(responseTokens);
+                startPosition = conversationTokens.size();
+                Integer stopToken = null;
+                if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
+                    stopToken = responseTokens.getLast();
+                    responseTokens.removeLast();
+                }
+                if (!options.stream()) {
+                    String responseText = tokenizer().decode(responseTokens);
+                    System.out.println(responseText);
+                }
+                if (stopToken == null) {
+                    System.err.println("\nRan out of context length...\nIncrease the context length by passing --max-tokens XXX to llama-tornado");
+                    break;
+                }
+                System.out.print("\n");
+
+                // Optionally print performance metrics after each response
+                if (SHOW_PERF_INTERACTIVE) {
+                    LastRunMetrics.printMetrics();
+                }
+            }
+        } finally {
+            // Clean up TornadoVM resources when exiting the chat loop
+            if (USE_TORNADOVM && tornadoVMPlan != null) {
+                try {
+                    tornadoVMPlan.freeTornadoExecutionPlan();
+                } catch (Exception e) {
+                    System.err.println("Error while cleaning up TornadoVM resources: " + e.getMessage());
+                }
+            }
+        }
+    }
     void runInstructOnce(Sampler sampler, Options options);
 }
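
Usage sketch (editor's note, not part of the commit): once a concrete Model, a Sampler, and parsed Options exist, the new default method is the entire interactive entry point. The helper below is hypothetical; the real wiring, including how the Model and Options are constructed, lives in LlamaApp and is not shown in this diff.

    // Hypothetical helper, for illustration only; uses the project's own
    // Model, Sampler, and Options types from the diff above.
    static void chat(Model model, Sampler sampler, Options options) {
        // Streaming output, the GPU/CPU path choice, and TornadoVM cleanup
        // are all handled inside the default runInteractive() implementation.
        model.runInteractive(sampler, options); // type "quit" or "exit" at the '>' prompt to leave
    }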