
Commit 55e4b34

[WIP] Update service logic
1 parent 7bab2d4 commit 55e4b34

2 files changed: +167 -129 lines changed
ModelConfiguration.java
Lines changed: 5 additions & 5 deletions
@@ -1,8 +1,8 @@
 package org.beehive.gpullama3.api.config;

+import org.beehive.gpullama3.api.service.LLMService;
 import org.beehive.gpullama3.model.Model;
 import org.beehive.gpullama3.Options;
-import org.beehive.gpullama3.api.service.ModelInitializationService;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Configuration;

@@ -13,15 +13,15 @@ public class ModelConfiguration {
      * Expose Model as a Spring bean using the initialized service
      */
     @Bean
-    public Model model(ModelInitializationService initService) {
-        return initService.getModel();
+    public Model model(LLMService llmService) {
+        return llmService.getModel();
     }

     /**
      * Expose Options as a Spring bean using the initialized service
      */
     @Bean
-    public Options options(ModelInitializationService initService) {
-        return initService.getOptions();
+    public Options options(LLMService llmService) {
+        return llmService.getOptions();
     }
 }
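
Context for this change: the configuration now delegates bean creation to LLMService rather than the removed ModelInitializationService. As a minimal sketch of how a downstream component could consume these re-exported beans via constructor injection (the controller name and endpoint here are hypothetical, not part of this commit):

// Illustrative sketch only; ModelInfoController and its mapping are
// hypothetical names, not part of this commit.
package org.beehive.gpullama3.api.controller;

import org.beehive.gpullama3.Options;
import org.beehive.gpullama3.model.Model;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

@RestController
public class ModelInfoController {

    private final Model model;
    private final Options options;

    // Constructor injection: Spring resolves both beans from ModelConfiguration
    public ModelInfoController(Model model, Options options) {
        this.model = model;
        this.options = options;
    }

    @GetMapping("/model-info")
    public String info() {
        return model.getClass().getSimpleName() + " @ " + options.modelPath();
    }
}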
LLMService.java
Lines changed: 162 additions & 124 deletions
@@ -1,173 +1,211 @@
 package org.beehive.gpullama3.api.service;

-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.inference.state.State;
+import jakarta.annotation.PostConstruct;
+import org.beehive.gpullama3.Options;
 import org.beehive.gpullama3.inference.sampler.Sampler;
-import org.springframework.beans.factory.annotation.Autowired;
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.format.ChatFormat;
+import org.beehive.gpullama3.model.loader.ModelLoader;
+import org.springframework.boot.ApplicationArguments;
 import org.springframework.stereotype.Service;
 import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

-import java.util.*;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Set;
 import java.util.concurrent.CompletableFuture;
-import java.util.function.IntConsumer;
+
+import static org.beehive.gpullama3.inference.sampler.Sampler.selectSampler;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadModel;

 @Service
 public class LLMService {

-    @Autowired
-    private ModelInitializationService initService;
+    private final ApplicationArguments args;

-    @Autowired
-    private TokenizerService tokenizerService;
+    private Options options;
+    private Model model;

-    public CompletableFuture<String> generateCompletion(
-            String prompt,
-            int maxTokens,
-            double temperature,
-            double topP,
-            List<String> stopSequences) {
+    public LLMService(ApplicationArguments args) {
+        this.args = args;
+    }

-        return CompletableFuture.supplyAsync(() -> {
-            try {
-                System.out.println("Starting completion generation...");
-                System.out.println("Prompt: " + prompt.substring(0, Math.min(50, prompt.length())) + "...");
-                System.out.println("Max tokens: " + maxTokens + ", Temperature: " + temperature);
-
-                // Get initialized components
-                Model model = initService.getModel();
-
-                // Convert prompt to tokens
-                List<Integer> promptTokens = tokenizerService.encode(prompt);
-                System.out.println("Prompt tokens: " + promptTokens.size());
-
-                // Convert stop sequences to token sets
-                Set<Integer> stopTokens = new HashSet<>();
-                if (stopSequences != null) {
-                    for (String stop : stopSequences) {
-                        stopTokens.addAll(tokenizerService.encode(stop));
-                    }
-                    System.out.println("Stop tokens: " + stopTokens.size());
-                }
+    @PostConstruct
+    public void init() {
+        try {
+            System.out.println("Initializing LLM service...");
+
+            // Step 1: Parse service options
+            System.out.println("Step 1: Parsing service options...");
+            options = Options.parseServiceOptions(args.getSourceArgs());
+            System.out.println("Model path: " + options.modelPath());
+            System.out.println("Context length: " + options.maxTokens());
+
+            // Step 2: Load model weights
+            System.out.println("\nStep 2: Loading model...");
+            System.out.println("Loading model from: " + options.modelPath());
+            model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true);
+            System.out.println("✓ Model loaded successfully");
+            System.out.println(" Model type: " + model.getClass().getSimpleName());
+            System.out.println(" Vocabulary size: " + model.configuration().vocabularySize());
+            System.out.println(" Context length: " + model.configuration().contextLength());
+
+            System.out.println("\n✓ Model service initialization completed successfully!");
+            System.out.println("=== Ready to serve requests ===\n");

-                // Create custom sampler with request-specific parameters
-                //Sampler sampler = initService.createCustomSampler(temperature, topP, System.currentTimeMillis());
-                Sampler sampler = initService.getSampler();
+        } catch (Exception e) {
+            System.err.println("✗ Failed to initialize model service: " + e.getMessage());
+            e.printStackTrace();
+            throw new RuntimeException("Model initialization failed", e);
+        }
+    }

-                // Create state based on model type
-                State state = createStateForModel(model);
+    public String generateResponse(String message, String systemMessage) {
+        return generateResponse(message, systemMessage, 150, 0.7, 0.9);
+    }

-                // Generate tokens using your existing method
-                List<Integer> generatedTokens = model.generateTokens(
-                        state,
-                        0,
-                        promptTokens,
-                        stopTokens,
-                        maxTokens,
-                        sampler,
-                        false,
-                        token -> {} // No callback for non-streaming
-                );
+    public String generateResponse(String message, String systemMessage, int maxTokens, double temperature, double topP) {
+        try {
+            // Create sampler and state like runInstructOnce
+            Sampler sampler = selectSampler(model.configuration().vocabularySize(), (float) temperature, (float) topP, System.currentTimeMillis());
+            State state = model.createNewState();

-                // Decode tokens back to text
-                String result = tokenizerService.decode(generatedTokens);
-                System.out.println("Generated " + generatedTokens.size() + " tokens");
-                System.out.println("Completion finished successfully");
+            // Use model's ChatFormat
+            ChatFormat chatFormat = model.chatFormat();
+            List<Integer> promptTokens = new ArrayList<>();

-                return result;
+            // Add begin of text if needed
+            if (model.shouldAddBeginOfText()) {
+                promptTokens.add(chatFormat.getBeginOfText());
+            }

-            } catch (Exception e) {
-                System.err.println("Error generating completion: " + e.getMessage());
-                e.printStackTrace();
-                throw new RuntimeException("Error generating completion", e);
+            // Add system message properly formatted
+            if (model.shouldAddSystemPrompt() && systemMessage != null && !systemMessage.trim().isEmpty()) {
+                promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, systemMessage)));
             }
-        });
-    }

-    public void generateStreamingCompletion(
-            String prompt,
-            int maxTokens,
-            double temperature,
-            double topP,
-            List<String> stopSequences,
-            SseEmitter emitter) {
+            // Add user message properly formatted
+            promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, message)));
+            promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
+
+            // Handle reasoning tokens if needed (for Deepseek-R1-Distill-Qwen)
+            if (model.shouldIncludeReasoning()) {
+                List<Integer> thinkStartTokens = model.tokenizer().encode("<think>\n", model.tokenizer().getSpecialTokens().keySet());
+                promptTokens.addAll(thinkStartTokens);
+            }
+
+            // Use proper stop tokens from chat format
+            Set<Integer> stopTokens = chatFormat.getStopTokens();
+
+            long startTime = System.currentTimeMillis();
+
+            // Use CPU path for now (GPU path disabled as noted)
+            List<Integer> generatedTokens = model.generateTokens(
+                    state, 0, promptTokens, stopTokens, maxTokens, sampler, false, token -> {}
+            );

+            // Remove stop tokens if present
+            if (!generatedTokens.isEmpty() && stopTokens.contains(generatedTokens.getLast())) {
+                generatedTokens.removeLast();
+            }
+
+            long duration = System.currentTimeMillis() - startTime;
+            double tokensPerSecond = generatedTokens.size() * 1000.0 / duration;
+            System.out.printf("COMPLETED tokens=%d duration=%dms rate=%.1f tok/s%n",
+                    generatedTokens.size(), duration, tokensPerSecond);
+
+
+            String responseText = model.tokenizer().decode(generatedTokens);
+
+            // Add reasoning prefix for non-streaming if needed
+            if (model.shouldIncludeReasoning()) {
+                responseText = "<think>\n" + responseText;
+            }
+
+            return responseText;
+
+        } catch (Exception e) {
+            System.err.println("FAILED " + e.getMessage());
+            throw new RuntimeException("Failed to generate response", e);
+        }
+    }
+
+    public void generateStreamingResponse(String message, String systemMessage, SseEmitter emitter) {
         CompletableFuture.runAsync(() -> {
             try {
-                System.out.println("Starting streaming completion generation...");
-
-                Model model = initService.getModel();
+                Sampler sampler = selectSampler(model.configuration().vocabularySize(), 0.7f, 0.9f, System.currentTimeMillis());
+                State state = model.createNewState();

-                List<Integer> promptTokens = tokenizerService.encode(prompt);
+                // Use proper chat format like in runInstructOnce
+                ChatFormat chatFormat = model.chatFormat();
+                List<Integer> promptTokens = new ArrayList<>();

-                Set<Integer> stopTokens = new HashSet<>();
-                if (stopSequences != null) {
-                    for (String stop : stopSequences) {
-                        stopTokens.addAll(tokenizerService.encode(stop));
-                    }
+                if (model.shouldAddBeginOfText()) {
+                    promptTokens.add(chatFormat.getBeginOfText());
                 }

-                //Sampler sampler = initService.createCustomSampler(temperature, topP, System.currentTimeMillis());
-                Sampler sampler = initService.getSampler();
-                State state = createStateForModel(model);
-
-                final int[] tokenCount = {0};
+                if (model.shouldAddSystemPrompt() && systemMessage != null && !systemMessage.trim().isEmpty()) {
+                    promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, systemMessage)));
+                }

-                // Streaming callback
-                IntConsumer tokenCallback = token -> {
-                    try {
-                        String tokenText = tokenizerService.decode(List.of(token));
-                        tokenCount[0]++;
+                promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, message)));
+                promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));

-                        String eventData = String.format(
-                                "data: {\"choices\":[{\"text\":\"%s\",\"index\":0,\"finish_reason\":null}]}\n\n",
-                                escapeJson(tokenText)
-                        );
+                // Handle reasoning tokens for streaming
+                if (model.shouldIncludeReasoning()) {
+                    List<Integer> thinkStartTokens = model.tokenizer().encode("<think>\n", model.tokenizer().getSpecialTokens().keySet());
+                    promptTokens.addAll(thinkStartTokens);
+                    emitter.send(SseEmitter.event().data("<think>\n")); // Output immediately
+                }

-                        emitter.send(SseEmitter.event().data(eventData));
+                Set<Integer> stopTokens = chatFormat.getStopTokens();

-                        if (tokenCount[0] % 10 == 0) {
-                            System.out.println("Streamed " + tokenCount[0] + " tokens");
+                final int[] tokenCount = {0};
+                long startTime = System.currentTimeMillis();
+                List<Integer> generatedTokens = model.generateTokens(
+                        state, 0, promptTokens, stopTokens, 150, sampler, false,
+                        token -> {
+                            try {
+                                // Only display tokens that should be displayed (like in your original)
+                                if (model.tokenizer().shouldDisplayToken(token)) {
+                                    String tokenText = model.tokenizer().decode(List.of(token));
+                                    emitter.send(SseEmitter.event().data(tokenText));
+                                    tokenCount[0]++;
+                                }
+                            } catch (Exception e) {
+                                emitter.completeWithError(e);
+                            }
                         }
+                );

-                    } catch (Exception e) {
-                        System.err.println("Error in streaming callback: " + e.getMessage());
-                        emitter.completeWithError(e);
-                    }
-                };
-
-                model.generateTokens(state, 0, promptTokens, stopTokens, maxTokens, sampler, false, tokenCallback);
+                long duration = System.currentTimeMillis() - startTime;
+                double tokensPerSecond = tokenCount[0] * 1000.0 / duration;
+                System.out.printf("COMPLETED tokens=%d duration=%dms rate=%.1f tok/s%n",
+                        tokenCount[0], duration, tokensPerSecond);

-                // Send completion event
-                emitter.send(SseEmitter.event().data("data: [DONE]\n\n"));
+                emitter.send(SseEmitter.event().data("[DONE]"));
                 emitter.complete();

-                System.out.println("Streaming completion finished. Total tokens: " + tokenCount[0]);
-
             } catch (Exception e) {
-                System.err.println("Error in streaming generation: " + e.getMessage());
-                e.printStackTrace();
+                System.err.println("FAILED " + e.getMessage());
                 emitter.completeWithError(e);
             }
         });
     }

-    /**
-     * Create appropriate State subclass based on the model type
-     */
-    private State createStateForModel(Model model) {
-        try {
-            return model.createNewState();
-        } catch (Exception e) {
-            throw new RuntimeException("Failed to create state for model", e);
+    // Getters for other services to access the initialized components
+    public Options getOptions() {
+        if (options == null) {
+            throw new IllegalStateException("Model service not initialized yet");
         }
+        return options;
     }

-    private String escapeJson(String str) {
-        if (str == null) return "";
-        return str.replace("\"", "\\\"")
-                .replace("\n", "\\n")
-                .replace("\r", "\\r")
-                .replace("\t", "\\t")
-                .replace("\\", "\\\\");
+    public Model getModel() {
+        if (model == null) {
+            throw new IllegalStateException("Model service not initialized yet");
+        }
+        return model;
     }
 }
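
For orientation, here is a minimal caller sketch for the reworked service API: generateResponse for blocking calls, generateStreamingResponse for SSE streaming. The controller, endpoints, and request parameters below are illustrative assumptions, not part of this commit.

// Hypothetical caller sketch (not part of this commit).
package org.beehive.gpullama3.api.controller;

import org.beehive.gpullama3.api.service.LLMService;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

@RestController
public class ChatController {

    private final LLMService llmService;

    public ChatController(LLMService llmService) {
        this.llmService = llmService;
    }

    // Blocking completion: returns the full decoded response text
    @PostMapping("/chat")
    public String chat(@RequestParam String message,
                       @RequestParam(required = false) String system) {
        return llmService.generateResponse(message, system);
    }

    // Streaming completion: tokens are pushed over SSE as they are sampled,
    // followed by a "[DONE]" event (see generateStreamingResponse above)
    @GetMapping("/chat/stream")
    public SseEmitter chatStream(@RequestParam String message,
                                 @RequestParam(required = false) String system) {
        SseEmitter emitter = new SseEmitter(0L); // no timeout
        llmService.generateStreamingResponse(message, system, emitter);
        return emitter;
    }
}

Since generateStreamingResponse runs its work on a CompletableFuture and completes the emitter itself (including the final "[DONE]" event), the controller can return the emitter immediately.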
