Skip to content

Commit 2b2ed4d

Browse files
[WIP] Update service logic
1 parent 55e4b34 commit 2b2ed4d

File tree

2 files changed

+58
-8
lines changed

2 files changed

+58
-8
lines changed

src/main/java/org/beehive/gpullama3/api/controller/ChatController.java

Lines changed: 35 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -19,31 +19,44 @@ public class ChatController {
1919

2020
@PostMapping("/chat")
2121
public Map<String, String> chat(@RequestBody ChatRequest request) {
22-
logRequest("NON_STREAMING", request, 150, 0.7, 0.9);
22+
// Use request parameters with fallbacks to defaults
23+
int maxTokens = request.getMaxTokens() != null ? request.getMaxTokens() : 150;
24+
double temperature = request.getTemperature() != null ? request.getTemperature() : 0.7;
25+
double topP = request.getTopP() != null ? request.getTopP() : 0.9;
26+
27+
logRequest("NON_STREAMING", request, maxTokens, temperature, topP);
2328

2429
if (request.getMessage() == null || request.getMessage().trim().isEmpty()) {
2530
throw new IllegalArgumentException("Message cannot be empty");
2631
}
2732

28-
String response = llmService.generateResponse(request.getMessage(), request.getSystemMessage());
33+
String response = llmService.generateResponse(request.getMessage(), request.getSystemMessage(),
34+
maxTokens, temperature, topP);
2935

3036
return Map.of("response", response);
3137
}
3238

3339
@PostMapping(value = "/chat/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
3440
public SseEmitter streamChat(@RequestBody ChatRequest request) {
35-
logRequest("STREAMING", request, 150, 0.7, 0.9);
41+
// Use request parameters with fallbacks to defaults
42+
int maxTokens = request.getMaxTokens() != null ? request.getMaxTokens() : 150;
43+
double temperature = request.getTemperature() != null ? request.getTemperature() : 0.7;
44+
double topP = request.getTopP() != null ? request.getTopP() : 0.9;
45+
46+
logRequest("STREAMING", request, maxTokens, temperature, topP);
3647

3748
if (request.getMessage() == null || request.getMessage().trim().isEmpty()) {
3849
throw new IllegalArgumentException("Message cannot be empty");
3950
}
4051

4152
SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
42-
llmService.generateStreamingResponse(request.getMessage(), request.getSystemMessage(), emitter);
53+
llmService.generateStreamingResponse(request.getMessage(), request.getSystemMessage(),
54+
emitter, maxTokens, temperature, topP);
4355

4456
return emitter;
4557
}
4658

59+
4760
@GetMapping("/health")
4861
public Map<String, String> health() {
4962
return Map.of("status", "healthy", "timestamp", String.valueOf(System.currentTimeMillis()));
@@ -69,11 +82,29 @@ private String truncate(String text, int maxLength) {
6982
public static class ChatRequest {
7083
private String message;
7184
private String systemMessage;
85+
private Integer maxTokens;
86+
private Double temperature;
87+
private Double topP;
88+
private Long seed;
7289

90+
// Getters and Setters
7391
public String getMessage() { return message; }
7492
public void setMessage(String message) { this.message = message; }
7593

7694
public String getSystemMessage() { return systemMessage; }
7795
public void setSystemMessage(String systemMessage) { this.systemMessage = systemMessage; }
96+
97+
public Integer getMaxTokens() { return maxTokens; }
98+
public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; }
99+
100+
public Double getTemperature() { return temperature; }
101+
public void setTemperature(Double temperature) { this.temperature = temperature; }
102+
103+
public Double getTopP() { return topP; }
104+
public void setTopP(Double topP) { this.topP = topP; }
105+
106+
public Long getSeed() { return seed; }
107+
public void setSeed(Long seed) { this.seed = seed; }
78108
}
109+
79110
}

src/main/java/org/beehive/gpullama3/api/service/LLMService.java

Lines changed: 23 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -61,14 +61,22 @@ public void init() {
6161
}
6262
}
6363

64+
/**
65+
* Generate response with default parameters.
66+
*/
6467
public String generateResponse(String message, String systemMessage) {
6568
return generateResponse(message, systemMessage, 150, 0.7, 0.9);
6669
}
6770

6871
public String generateResponse(String message, String systemMessage, int maxTokens, double temperature, double topP) {
72+
return generateResponse(message, systemMessage, maxTokens, temperature, topP, null);
73+
}
74+
75+
public String generateResponse(String message, String systemMessage, int maxTokens, double temperature, double topP, Long seed) {
6976
try {
7077
// Create sampler and state like runInstructOnce
71-
Sampler sampler = selectSampler(model.configuration().vocabularySize(), (float) temperature, (float) topP, System.currentTimeMillis());
78+
long actualSeed = seed != null ? seed : System.currentTimeMillis();
79+
Sampler sampler = selectSampler(model.configuration().vocabularySize(), (float) temperature, (float) topP, actualSeed);
7280
State state = model.createNewState();
7381

7482
// Use model's ChatFormat
@@ -115,7 +123,6 @@ public String generateResponse(String message, String systemMessage, int maxToke
115123
System.out.printf("COMPLETED tokens=%d duration=%dms rate=%.1f tok/s%n",
116124
generatedTokens.size(), duration, tokensPerSecond);
117125

118-
119126
String responseText = model.tokenizer().decode(generatedTokens);
120127

121128
// Add reasoning prefix for non-streaming if needed
@@ -132,9 +139,20 @@ public String generateResponse(String message, String systemMessage, int maxToke
132139
}
133140

134141
public void generateStreamingResponse(String message, String systemMessage, SseEmitter emitter) {
142+
generateStreamingResponse(message, systemMessage, emitter, 150, 0.7, 0.9);
143+
}
144+
145+
public void generateStreamingResponse(String message, String systemMessage, SseEmitter emitter,
146+
int maxTokens, double temperature, double topP) {
147+
generateStreamingResponse(message, systemMessage, emitter, maxTokens, temperature, topP, null);
148+
}
149+
150+
public void generateStreamingResponse(String message, String systemMessage, SseEmitter emitter,
151+
int maxTokens, double temperature, double topP, Long seed) {
135152
CompletableFuture.runAsync(() -> {
136153
try {
137-
Sampler sampler = selectSampler(model.configuration().vocabularySize(), 0.7f, 0.9f, System.currentTimeMillis());
154+
long actualSeed = seed != null ? seed : System.currentTimeMillis();
155+
Sampler sampler = selectSampler(model.configuration().vocabularySize(), (float) temperature, (float) topP, actualSeed);
138156
State state = model.createNewState();
139157

140158
// Use proper chat format like in runInstructOnce
@@ -164,13 +182,14 @@ public void generateStreamingResponse(String message, String systemMessage, SseE
164182
final int[] tokenCount = {0};
165183
long startTime = System.currentTimeMillis();
166184
List<Integer> generatedTokens = model.generateTokens(
167-
state, 0, promptTokens, stopTokens, 150, sampler, false,
185+
state, 0, promptTokens, stopTokens, maxTokens, sampler, false,
168186
token -> {
169187
try {
170188
// Only display tokens that should be displayed (like in your original)
171189
if (model.tokenizer().shouldDisplayToken(token)) {
172190
String tokenText = model.tokenizer().decode(List.of(token));
173191
emitter.send(SseEmitter.event().data(tokenText));
192+
//emitter.send(SseEmitter.event().comment("flush"));
174193
tokenCount[0]++;
175194
}
176195
} catch (Exception e) {

0 commit comments

Comments (0)