@@ -1,24 +1,11 @@
package com.example.model.llama;

- import com.example.auxiliary.LastRunMetrics;
- import com.example.auxiliary.format.LlamaChatFormat;
- import com.example.inference.InferenceEngine;
- import com.example.inference.sampler.Sampler;
import com.example.model.Model;
- import com.example.Options;
import com.example.loader.weights.ModelLoader;
import com.example.loader.weights.State;
import com.example.loader.weights.Weights;
import com.example.tokenizer.impl.LlamaTokenizer;
import com.example.tokenizer.impl.Tokenizer;
- import com.example.tornadovm.TornadoVMMasterPlan;
-
- import java.util.ArrayList;
- import java.util.List;
- import java.util.Set;
- import java.util.function.IntConsumer;
-
- import static com.example.LlamaApp.USE_TORNADOVM;

public record Llama(LlamaConfiguration configuration, Tokenizer tokenizer, Weights weights) implements Model {
    private static final int BATCH_SIZE = Integer.getInteger("llama.BatchSize", 16);
@@ -45,57 +32,5 @@ public State createNewState(int batchsize) {
        return state;
    }

-    @Override
-    public void runInstructOnce(Sampler sampler, Options options) {
-        State state = createNewState();
-        LlamaChatFormat chatFormat = new LlamaChatFormat(getAsLlamaTokenizer());
-        TornadoVMMasterPlan tornadoVMPlan = null;
-
-        List<Integer> promptTokens = new ArrayList<>();
-        promptTokens.add(chatFormat.getBeginOfText());
-
-        if (options.systemPrompt() != null) {
-            promptTokens.addAll(chatFormat.encodeMessage(new LlamaChatFormat.Message(LlamaChatFormat.Role.SYSTEM, options.systemPrompt())));
-        }
-        promptTokens.addAll(chatFormat.encodeMessage(new LlamaChatFormat.Message(LlamaChatFormat.Role.USER, options.prompt())));
-        promptTokens.addAll(chatFormat.encodeHeader(new LlamaChatFormat.Message(LlamaChatFormat.Role.ASSISTANT, "")));
-        List<Integer> responseTokens;
-
-        // Define the token consumer
-        IntConsumer tokenConsumer = token -> {
-            if (options.stream()) {
-                if (!tokenizer.isSpecialToken(token)) {
-                    System.out.print(tokenizer.decode(List.of(token)));
-                }
-            }
-        };
-
-        Set<Integer> stopTokens = chatFormat.getStopTokens();
-        if (USE_TORNADOVM) {
-            tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
-            // Call generateTokensGPU without the token consumer parameter
-            responseTokens = InferenceEngine.generateTokensGPU(this, state, 0, promptTokens, stopTokens,
-                    options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
-        } else {
-            // CPU path still uses the token consumer
-            responseTokens = InferenceEngine.generateTokens(this, state, 0, promptTokens, stopTokens,
-                    options.maxTokens(), sampler, options.echo(), tokenConsumer);
-        }
-
-        if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
-            responseTokens.removeLast();
-        }
-        if (!options.stream()) {
-            String responseText = tokenizer.decode(responseTokens);
-            System.out.println(responseText);
-        }
-
-        LastRunMetrics.printMetrics();
-
-        if (tornadoVMPlan != null) {
-            tornadoVMPlan.freeTornadoExecutionPlan();
-        }
-    }
-
}
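
Note: the removed runInstructOnce built the chat prompt, then streamed sampled tokens through an IntConsumer that skipped special tokens and printed each one as it arrived (passed as null on the GPU path when streaming was off). A minimal, self-contained sketch of that consumer pattern follows; MiniTokenizer and the loop below are illustrative stand-ins, not the project's Tokenizer or InferenceEngine API, and the special-token cutoff is hypothetical.

    import java.util.List;
    import java.util.function.IntConsumer;

    public class TokenStreamSketch {
        // Stand-in for the two tokenizer methods the removed consumer used.
        interface MiniTokenizer {
            String decode(List<Integer> tokens);
            boolean isSpecialToken(int token);
        }

        public static void main(String[] args) {
            MiniTokenizer tokenizer = new MiniTokenizer() {
                public String decode(List<Integer> tokens) { return "<" + tokens.get(0) + ">"; }
                public boolean isSpecialToken(int token) { return token >= 128000; } // hypothetical cutoff
            };

            // Same shape as the removed tokenConsumer: filter specials, print incrementally.
            IntConsumer tokenConsumer = token -> {
                if (!tokenizer.isSpecialToken(token)) {
                    System.out.print(tokenizer.decode(List.of(token)));
                }
            };

            // Stand-in generation loop feeding sampled token ids to the consumer.
            for (int token : List.of(15339, 1917, 128001)) {
                tokenConsumer.accept(token);
            }
            System.out.println();
        }
    }

As in the removed method, suppressing the consumer (or passing null) simply defers output until the full response is decoded at the end.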