package org.beehive.gpullama3.api.service;

- import org.beehive.gpullama3.model.Model;
- import org.beehive.gpullama3.inference.state.State;
+ import jakarta.annotation.PostConstruct;
+ import org.beehive.gpullama3.Options;
import org.beehive.gpullama3.inference.sampler.Sampler;
- import org.springframework.beans.factory.annotation.Autowired;
+ import org.beehive.gpullama3.inference.state.State;
+ import org.beehive.gpullama3.model.Model;
+ import org.beehive.gpullama3.model.format.ChatFormat;
+ import org.beehive.gpullama3.model.loader.ModelLoader;
+ import org.springframework.boot.ApplicationArguments;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

- import java.util.*;
+ import java.util.ArrayList;
+ import java.util.List;
+ import java.util.Set;
import java.util.concurrent.CompletableFuture;
- import java.util.function.IntConsumer;
+
+ import static org.beehive.gpullama3.inference.sampler.Sampler.selectSampler;
+ import static org.beehive.gpullama3.model.loader.ModelLoader.loadModel;

@Service
public class LLMService {

-     @Autowired
-     private ModelInitializationService initService;
+     private final ApplicationArguments args;

-     @Autowired
-     private TokenizerService tokenizerService;
+     private Options options;
+     private Model model;

-     public CompletableFuture<String> generateCompletion(
-             String prompt,
-             int maxTokens,
-             double temperature,
-             double topP,
-             List<String> stopSequences) {
+     public LLMService(ApplicationArguments args) {
+         this.args = args;
+     }

-         return CompletableFuture.supplyAsync(() -> {
-             try {
-                 System.out.println("Starting completion generation...");
-                 System.out.println("Prompt: " + prompt.substring(0, Math.min(50, prompt.length())) + "...");
-                 System.out.println("Max tokens: " + maxTokens + ", Temperature: " + temperature);
-
-                 // Get initialized components
-                 Model model = initService.getModel();
-
-                 // Convert prompt to tokens
-                 List<Integer> promptTokens = tokenizerService.encode(prompt);
-                 System.out.println("Prompt tokens: " + promptTokens.size());
-
-                 // Convert stop sequences to token sets
-                 Set<Integer> stopTokens = new HashSet<>();
-                 if (stopSequences != null) {
-                     for (String stop : stopSequences) {
-                         stopTokens.addAll(tokenizerService.encode(stop));
-                     }
-                     System.out.println("Stop tokens: " + stopTokens.size());
-                 }
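+     /**
+      * Eagerly parses the service options and loads the model at startup, so the
+      * first request pays no loading cost and a bad model path fails the boot
+      * instead of a request.
+      */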
+     @PostConstruct
+     public void init() {
+         try {
+             System.out.println("Initializing LLM service...");
+
+             // Step 1: Parse service options
+             System.out.println("Step 1: Parsing service options...");
+             options = Options.parseServiceOptions(args.getSourceArgs());
+             System.out.println("Model path: " + options.modelPath());
+             System.out.println("Context length: " + options.maxTokens());
+
+             // Step 2: Load model weights
+             System.out.println("\nStep 2: Loading model...");
+             System.out.println("Loading model from: " + options.modelPath());
+             model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true);
+             System.out.println("✓ Model loaded successfully");
+             System.out.println("  Model type: " + model.getClass().getSimpleName());
+             System.out.println("  Vocabulary size: " + model.configuration().vocabularySize());
+             System.out.println("  Context length: " + model.configuration().contextLength());
+
+             System.out.println("\n✓ Model service initialization completed successfully!");
+             System.out.println("=== Ready to serve requests ===\n");

-                 // Create custom sampler with request-specific parameters
-                 //Sampler sampler = initService.createCustomSampler(temperature, topP, System.currentTimeMillis());
-                 Sampler sampler = initService.getSampler();
+         } catch (Exception e) {
+             System.err.println("✗ Failed to initialize model service: " + e.getMessage());
+             e.printStackTrace();
+             throw new RuntimeException("Model initialization failed", e);
+         }
+     }

-                 // Create state based on model type
-                 State state = createStateForModel(model);
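+     /** Convenience overload with default sampling: 150 max tokens, temperature 0.7, top-p 0.9. */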
+     public String generateResponse(String message, String systemMessage) {
+         return generateResponse(message, systemMessage, 150, 0.7, 0.9);
+     }

-                 // Generate tokens using your existing method
-                 List<Integer> generatedTokens = model.generateTokens(
-                         state,
-                         0,
-                         promptTokens,
-                         stopTokens,
-                         maxTokens,
-                         sampler,
-                         false,
-                         token -> {} // No callback for non-streaming
-                 );
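+     /**
+      * Synchronous, non-streaming generation: builds a chat-formatted prompt from
+      * the system and user messages, samples up to maxTokens tokens, and returns
+      * the decoded text in one piece.
+      */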
+ public String generateResponse (String message , String systemMessage , int maxTokens , double temperature , double topP ) {
69
+ try {
70
+ // Create sampler and state like runInstructOnce
71
+ Sampler sampler = selectSampler (model .configuration ().vocabularySize (), (float ) temperature , (float ) topP , System .currentTimeMillis ());
72
+ State state = model .createNewState ();
70
73
71
- // Decode tokens back to text
72
- String result = tokenizerService .decode (generatedTokens );
73
- System .out .println ("Generated " + generatedTokens .size () + " tokens" );
74
- System .out .println ("Completion finished successfully" );
74
+ // Use model's ChatFormat
75
+ ChatFormat chatFormat = model .chatFormat ();
76
+ List <Integer > promptTokens = new ArrayList <>();
75
77
76
- return result ;
78
+ // Add begin of text if needed
79
+ if (model .shouldAddBeginOfText ()) {
80
+ promptTokens .add (chatFormat .getBeginOfText ());
81
+ }
77
82
78
- } catch (Exception e ) {
79
- System .err .println ("Error generating completion: " + e .getMessage ());
80
- e .printStackTrace ();
81
- throw new RuntimeException ("Error generating completion" , e );
83
+ // Add system message properly formatted
84
+ if (model .shouldAddSystemPrompt () && systemMessage != null && !systemMessage .trim ().isEmpty ()) {
85
+ promptTokens .addAll (chatFormat .encodeMessage (new ChatFormat .Message (ChatFormat .Role .SYSTEM , systemMessage )));
82
86
}
83
- });
84
- }
85
87
86
- public void generateStreamingCompletion (
87
- String prompt ,
88
- int maxTokens ,
89
- double temperature ,
90
- double topP ,
91
- List <String > stopSequences ,
92
- SseEmitter emitter ) {
88
+ // Add user message properly formatted
89
+ promptTokens .addAll (chatFormat .encodeMessage (new ChatFormat .Message (ChatFormat .Role .USER , message )));
90
+ promptTokens .addAll (chatFormat .encodeHeader (new ChatFormat .Message (ChatFormat .Role .ASSISTANT , "" )));
91
+
92
+ // Handle reasoning tokens if needed (for Deepseek-R1-Distill-Qwen)
93
+ if (model .shouldIncludeReasoning ()) {
94
+ List <Integer > thinkStartTokens = model .tokenizer ().encode ("<think>\n " , model .tokenizer ().getSpecialTokens ().keySet ());
95
+ promptTokens .addAll (thinkStartTokens );
96
+ }
97
+
98
+ // Use proper stop tokens from chat format
99
+ Set <Integer > stopTokens = chatFormat .getStopTokens ();
100
+
101
+ long startTime = System .currentTimeMillis ();
102
+
103
+ // Use CPU path for now (GPU path disabled as noted)
104
+ List <Integer > generatedTokens = model .generateTokens (
105
+ state , 0 , promptTokens , stopTokens , maxTokens , sampler , false , token -> {}
106
+ );
93
107
108
+ // Remove stop tokens if present
109
+ if (!generatedTokens .isEmpty () && stopTokens .contains (generatedTokens .getLast ())) {
110
+ generatedTokens .removeLast ();
111
+ }
112
+
113
+ long duration = System .currentTimeMillis () - startTime ;
114
+ double tokensPerSecond = generatedTokens .size () * 1000.0 / duration ;
115
+ System .out .printf ("COMPLETED tokens=%d duration=%dms rate=%.1f tok/s%n" ,
116
+ generatedTokens .size (), duration , tokensPerSecond );
117
+
118
+
119
+ String responseText = model .tokenizer ().decode (generatedTokens );
120
+
121
+ // Add reasoning prefix for non-streaming if needed
122
+ if (model .shouldIncludeReasoning ()) {
123
+ responseText = "<think>\n " + responseText ;
124
+ }
125
+
126
+ return responseText ;
127
+
128
+ } catch (Exception e ) {
129
+ System .err .println ("FAILED " + e .getMessage ());
130
+ throw new RuntimeException ("Failed to generate response" , e );
131
+ }
132
+ }
133
+
134
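+     /**
+      * Streams a response over SSE from a background task. Note that this path
+      * currently hardcodes its sampling parameters (temperature 0.7, top-p 0.9)
+      * and a 150-token cap. Each displayable token is sent as its own SSE event,
+      * followed by a final "[DONE]" event.
+      */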
+     public void generateStreamingResponse(String message, String systemMessage, SseEmitter emitter) {
          CompletableFuture.runAsync(() -> {
              try {
-                 System.out.println("Starting streaming completion generation...");
-
-                 Model model = initService.getModel();
+                 Sampler sampler = selectSampler(model.configuration().vocabularySize(), 0.7f, 0.9f, System.currentTimeMillis());
+                 State state = model.createNewState();

-                 List<Integer> promptTokens = tokenizerService.encode(prompt);
+                 // Use the proper chat format, as in runInstructOnce
+                 ChatFormat chatFormat = model.chatFormat();
+                 List<Integer> promptTokens = new ArrayList<>();

-                 Set<Integer> stopTokens = new HashSet<>();
-                 if (stopSequences != null) {
-                     for (String stop : stopSequences) {
-                         stopTokens.addAll(tokenizerService.encode(stop));
-                     }
+                 if (model.shouldAddBeginOfText()) {
+                     promptTokens.add(chatFormat.getBeginOfText());
                  }

-                 //Sampler sampler = initService.createCustomSampler(temperature, topP, System.currentTimeMillis());
-                 Sampler sampler = initService.getSampler();
-                 State state = createStateForModel(model);
-
-                 final int[] tokenCount = {0};
+                 if (model.shouldAddSystemPrompt() && systemMessage != null && !systemMessage.trim().isEmpty()) {
+                     promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, systemMessage)));
+                 }

-                 // Streaming callback
-                 IntConsumer tokenCallback = token -> {
-                     try {
-                         String tokenText = tokenizerService.decode(List.of(token));
-                         tokenCount[0]++;
+                 promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, message)));
+                 promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));

-                         String eventData = String.format(
-                                 "data: {\"choices\":[{\"text\":\"%s\",\"index\":0,\"finish_reason\":null}]}\n\n",
-                                 escapeJson(tokenText)
-                         );
+                 // Handle reasoning tokens for streaming
+                 if (model.shouldIncludeReasoning()) {
+                     List<Integer> thinkStartTokens = model.tokenizer().encode("<think>\n", model.tokenizer().getSpecialTokens().keySet());
+                     promptTokens.addAll(thinkStartTokens);
+                     emitter.send(SseEmitter.event().data("<think>\n")); // emit the prefix immediately
+                 }

-                         emitter.send(SseEmitter.event().data(eventData));
+                 Set<Integer> stopTokens = chatFormat.getStopTokens();

-                         if (tokenCount[0] % 10 == 0) {
-                             System.out.println("Streamed " + tokenCount[0] + " tokens");
+                 final int[] tokenCount = {0};
+                 long startTime = System.currentTimeMillis();
+                 List<Integer> generatedTokens = model.generateTokens(
+                         state, 0, promptTokens, stopTokens, 150, sampler, false,
+                         token -> {
+                             try {
+                                 // Only emit tokens meant for display (skips special tokens)
+                                 if (model.tokenizer().shouldDisplayToken(token)) {
+                                     String tokenText = model.tokenizer().decode(List.of(token));
+                                     emitter.send(SseEmitter.event().data(tokenText));
+                                     tokenCount[0]++;
+                                 }
+                             } catch (Exception e) {
+                                 emitter.completeWithError(e);
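+                                 // Note: this only closes the SSE stream; token
+                                 // generation itself still runs to completion.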
+                             }
                          }
+                 );

-                     } catch (Exception e) {
-                         System.err.println("Error in streaming callback: " + e.getMessage());
-                         emitter.completeWithError(e);
-                     }
-                 };
-
-                 model.generateTokens(state, 0, promptTokens, stopTokens, maxTokens, sampler, false, tokenCallback);
+                 long duration = System.currentTimeMillis() - startTime;
+                 double tokensPerSecond = tokenCount[0] * 1000.0 / duration;
+                 System.out.printf("COMPLETED tokens=%d duration=%dms rate=%.1f tok/s%n",
+                         tokenCount[0], duration, tokensPerSecond);

-                 // Send completion event
-                 emitter.send(SseEmitter.event().data("data: [DONE]\n\n"));
+                 emitter.send(SseEmitter.event().data("[DONE]"));
                  emitter.complete();

-                 System.out.println("Streaming completion finished. Total tokens: " + tokenCount[0]);
-
              } catch (Exception e) {
-                 System.err.println("Error in streaming generation: " + e.getMessage());
-                 e.printStackTrace();
+                 System.err.println("FAILED " + e.getMessage());
                  emitter.completeWithError(e);
              }
          });
      }

-     /**
-      * Create appropriate State subclass based on the model type
-      */
-     private State createStateForModel(Model model) {
-         try {
-             return model.createNewState();
-         } catch (Exception e) {
-             throw new RuntimeException("Failed to create state for model", e);
+     // Getters for other services to access the initialized components
+     public Options getOptions() {
+         if (options == null) {
+             throw new IllegalStateException("Model service not initialized yet");
          }
+         return options;
      }

-     private String escapeJson(String str) {
-         if (str == null) return "";
-         return str.replace("\"", "\\\"")
-                 .replace("\n", "\\n")
-                 .replace("\r", "\\r")
-                 .replace("\t", "\\t")
-                 .replace("\\", "\\\\");
+     public Model getModel() {
+         if (model == null) {
+             throw new IllegalStateException("Model service not initialized yet");
+         }
+         return model;
      }
  }
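
For context, a minimal sketch of how a controller could expose this service. It is hypothetical: this commit contains no controller, and the class name, routes, and request shape below are assumptions, not part of the diff.

    import org.springframework.web.bind.annotation.*;
    import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

    @RestController
    class ChatController {

        private final LLMService llmService;

        ChatController(LLMService llmService) {
            this.llmService = llmService;
        }

        // Blocking call; uses the service defaults (150 tokens, temp 0.7, top-p 0.9)
        @PostMapping("/chat")
        String chat(@RequestBody String message) {
            return llmService.generateResponse(message, "You are a helpful assistant.");
        }

        // The emitter is returned at once; tokens are pushed as they are sampled
        @GetMapping("/chat/stream")
        SseEmitter chatStream(@RequestParam String message) {
            SseEmitter emitter = new SseEmitter(0L); // 0 = no async timeout
            llmService.generateStreamingResponse(message, null, emitter);
            return emitter;
        }
    }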