@@ -21,12 +21,15 @@
 
 public interface Model {
     Configuration configuration();
+
     Tokenizer tokenizer();
+
     Weights weights();
 
     ModelType getModelType();
 
     State createNewState();
+
     State createNewState(int batchsize);
 
     /**
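For orientation, a minimal sketch of how the two `createNewState` overloads above might be exercised; `loadModel()` and the batch size are hypothetical placeholders, and only the `Model` methods come from this diff:

```java
// Hedged sketch, not code from this PR: calling the createNewState overloads.
Model model = loadModel();                     // hypothetical helper returning a Model
State singleState = model.createNewState();    // presumably uses a default batch size
State batchedState = model.createNewState(8);  // explicit batch size (8 is illustrative)
```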
|
@@ -85,14 +88,12 @@ default void runInteractive(Sampler sampler, Options options) {
         // Choose between GPU and CPU path based on configuration
         if (USE_TORNADOVM) {
             // GPU path using TornadoVM
-            responseTokens = InferenceEngine.generateTokensGPU(this, state, startPosition,
-                    conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens,
+            responseTokens = InferenceEngine.generateTokensGPU(this, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens,
                     options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
         } else {
             // CPU path
-            responseTokens = InferenceEngine.generateTokens(this, state, startPosition,
-                    conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens,
-                    options.maxTokens(), sampler, options.echo(), tokenConsumer);
+            responseTokens = InferenceEngine.generateTokens(this, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(),
+                    sampler, options.echo(), tokenConsumer);
         }
 
         // Include stop token in the prompt history, but not in the response displayed to the user.
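A side note on the `subList(startPosition, conversationTokens.size())` idiom that both call sites share: it passes only the tokens appended since `startPosition`, as a view over the conversation history rather than a copy. A self-contained illustration (the token values are made up):

```java
import java.util.List;

// Illustrative only: the subList(startPosition, size()) idiom from the diff above.
public class SubListDemo {
    public static void main(String[] args) {
        List<Integer> conversationTokens = List.of(1, 2, 3, 4, 5, 6);
        int startPosition = 4; // tokens 0..3 were already processed (illustrative)
        // Only the tokens added since startPosition are handed to the engine.
        List<Integer> newTokens = conversationTokens.subList(startPosition, conversationTokens.size());
        System.out.println(newTokens); // prints [5, 6]
    }
}
```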
|
@@ -164,11 +165,10 @@ default void runInstructOnce(Sampler sampler, Options options) {
         if (USE_TORNADOVM) {
             tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
             // Call generateTokensGPU without the token consumer parameter
-            responseTokens = InferenceEngine.generateTokensGPU(this, state, 0, promptTokens, stopTokens,
-                    options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
+            responseTokens = InferenceEngine.generateTokensGPU(this, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null,
+                    tornadoVMPlan);
         } else {
-            responseTokens = InferenceEngine.generateTokens(this, state, 0, promptTokens, stopTokens,
-                    options.maxTokens(), sampler, options.echo(), tokenConsumer);
+            responseTokens = InferenceEngine.generateTokens(this, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
         }
 
         if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
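The context line above, together with the earlier comment about keeping the stop token in the prompt history but not in the displayed response, suggests a trim-before-display step. A hedged, runnable sketch of that pattern; the token ids are invented, and only the `getLast()`/`contains()` check mirrors the diff:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

// Hedged sketch of the stop-token handling implied by the surrounding lines:
// keep the stop token in the history, strip it from the displayed response.
public class StopTokenDemo {
    public static void main(String[] args) {
        List<Integer> responseTokens = new ArrayList<>(List.of(42, 7, 128001));
        Set<Integer> stopTokens = Set.of(128001); // illustrative end-of-turn id
        if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
            // getLast()/removeLast() are Java 21 SequencedCollection methods.
            responseTokens.removeLast(); // drop the stop token before display
        }
        System.out.println(responseTokens); // prints [42, 7]
    }
}
```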