diff --git a/llama-tornado b/llama-tornado
index b59473f2..237cabf0 100755
--- a/llama-tornado
+++ b/llama-tornado
@@ -169,11 +169,14 @@ class LlamaRunner:
]
)
+ # Choose main class based on mode
+ main_class = "org.beehive.gpullama3.api.LLMApiApplication" if args.service else "org.beehive.gpullama3.LlamaApp"
+
module_config.extend(
[
"-cp",
self._find_llama_jar(),
- "org.beehive.gpullama3.LlamaApp",
+ main_class,
]
)
cmd.extend(module_config)
@@ -200,33 +203,36 @@ class LlamaRunner:
def _add_llama_args(self, cmd: List[str], args: argparse.Namespace) -> List[str]:
"""Add LLaMA-specific arguments to the command."""
+
+ # For service mode, only pass the model path and max-tokens
+ if hasattr(args, 'service') and args.service:
+ llama_args = [
+ "--model", args.model_path,
+ "--max-tokens", str(args.max_tokens),
+ ]
+ return cmd + llama_args
+
llama_args = [
- "-m",
- args.model_path,
- "--temperature",
- str(args.temperature),
- "--top-p",
- str(args.top_p),
- "--seed",
- str(args.seed),
- "--max-tokens",
- str(args.max_tokens),
- "--stream",
- str(args.stream).lower(),
- "--echo",
- str(args.echo).lower(),
+ "--model", args.model_path,
+ "--temperature", str(args.temperature),
+ "--top-p", str(args.top_p),
+ "--seed", str(args.seed),
+ "--max-tokens", str(args.max_tokens),
+ "--stream", str(args.stream).lower(),
+ "--echo", str(args.echo).lower(),
+ "--instruct" # Both modes use instruct
]
- if args.prompt:
- llama_args.extend(["-p", args.prompt])
+ # Only add prompt-related args for standalone mode
+ if not hasattr(args, 'service') or not args.service:
+ if hasattr(args, 'prompt') and args.prompt:
+ llama_args.extend(["-p", args.prompt])
- if args.system_prompt:
- llama_args.extend(["-sp", args.system_prompt])
+ if hasattr(args, 'system_prompt') and args.system_prompt:
+ llama_args.extend(["-sp", args.system_prompt])
- if args.interactive:
- llama_args.append("--interactive")
- elif args.instruct:
- llama_args.append("--instruct")
+ if hasattr(args, 'interactive') and args.interactive:
+ llama_args[-1] = "--interactive" # Replace --instruct
return cmd + llama_args
@@ -238,6 +244,30 @@ class LlamaRunner:
cmd = self._build_base_command(args)
cmd = self._add_llama_args(cmd, args)
+ # Show service-specific information
+ if args.service:
+ print("Starting GPULlama3.java REST API Service...")
+ print(f"Model: {args.model_path}")
+ # Display GPU/backend configuration
+ if args.use_gpu:
+ print(f"GPU Acceleration: Enabled ({args.backend.value.upper()} backend)")
+ print(f"GPU Memory: {args.gpu_memory}")
+ else:
+ print("GPU Acceleration: Disabled (CPU mode)")
+ print("API endpoints available at:")
+ print(" - http://localhost:8080/chat")
+ print(" - http://localhost:8080/chat/stream")
+ print(" - http://localhost:8080/health")
+ print("")
+ print("Example usage:")
+ print(' curl -X POST http://localhost:8080/chat \\')
+ print(' -H "Content-Type: application/json" \\')
+ print(' -d \'{"message": "Hello!"}\'')
+ print("")
+ print("Press Ctrl+C to stop the service")
+ print("-" * 60)
+
+
# Print command if requested (before verbose output)
if args.show_command:
print("Full Java command:")
@@ -390,6 +420,11 @@ def create_parser() -> argparse.ArgumentParser:
default=True,
help="Run in instruction mode (default)",
)
+ mode_group.add_argument(
+ "--service",
+ action="store_true",
+ help="Run as REST API service instead of standalone application"
+ )
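+ # Hypothetical launch sketch (flag names other than --service are defined
+ # elsewhere in this parser; the model path is illustrative):
+ #   llama-tornado --service --model /path/to/model.gguf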
# Hardware configuration
hw_group = parser.add_argument_group("Hardware Configuration")
diff --git a/pom.xml b/pom.xml
index 624fc671..a031256d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -12,6 +12,9 @@
<maven.compiler.source>21</maven.compiler.source>
<maven.compiler.target>21</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+ <spring.boot.version>3.2.0</spring.boot.version>
+ <jakarta.version>3.0.0</jakarta.version>
+ <jackson.version>2.16.1</jackson.version>
</properties>
@@ -32,6 +35,26 @@
<artifactId>tornado-runtime</artifactId>
<version>1.1.2-dev</version>
</dependency>
+
+ <!-- Spring Boot Web: REST controllers and embedded server -->
+ <dependency>
+ <groupId>org.springframework.boot</groupId>
+ <artifactId>spring-boot-starter-web</artifactId>
+ <version>${spring.boot.version}</version>
+ </dependency>
+
+ <dependency>
+ <groupId>jakarta.annotation</groupId>
+ <artifactId>jakarta.annotation-api</artifactId>
+ <version>${jakarta.version}</version>
+ </dependency>
+
+ <!-- Jackson for JSON serialisation -->
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ <version>${jackson.version}</version>
+ </dependency>
diff --git a/src/main/java/org/beehive/gpullama3/Options.java b/src/main/java/org/beehive/gpullama3/Options.java
index 98924481..606621a7 100644
--- a/src/main/java/org/beehive/gpullama3/Options.java
+++ b/src/main/java/org/beehive/gpullama3/Options.java
@@ -5,12 +5,15 @@
import java.nio.file.Paths;
public record Options(Path modelPath, String prompt, String systemPrompt, String suffix, boolean interactive, float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo,
- boolean useTornadovm) {
+ boolean useTornadovm, boolean serviceMode) {
public static final int DEFAULT_MAX_TOKENS = 1024;
public Options {
- require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\"");
+ // Skip prompt validation in service mode
+ if (!serviceMode) {
+ require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\"");
+ }
require(0 <= temperature, "Invalid argument: --temperature must be non-negative");
require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]");
}
@@ -61,7 +64,7 @@ public static Options getDefaultOptions() {
boolean echo = false;
boolean useTornadoVM = getDefaultTornadoVM();
- return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadoVM);
+ return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadoVM, false);
}
public static Options parseOptions(String[] args) {
@@ -123,6 +126,59 @@ public static Options parseOptions(String[] args) {
useTornadovm = getDefaultTornadoVM();
}
- return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadovm);
+ return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadovm, false);
+ }
+
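+ /**
+  * Minimal parser for service mode. Only --model/-m, --max-tokens/-n and
+  * --use-tornadovm are recognised, in either "--opt value" or "--opt=value"
+  * form, e.g. "--model /path/to/model.gguf --max-tokens=1024" (path is
+  * illustrative); unrecognised dash-prefixed options are skipped.
+  */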
+ public static Options parseServiceOptions(String[] args) {
+ Path modelPath = null;
+ int maxTokens = 512; // Default context length
+ Boolean useTornadovm = null;
+
+ for (int i = 0; i < args.length; i++) {
+ String optionName = args[i];
+ require(optionName.startsWith("-"), "Invalid option %s", optionName);
+
+ String nextArg;
+ if (optionName.contains("=")) {
+ String[] parts = optionName.split("=", 2);
+ optionName = parts[0];
+ nextArg = parts[1];
+ } else {
+ if (i + 1 >= args.length) continue; // Skip if no next arg
+ nextArg = args[i + 1];
+ i += 1; // skip arg
+ }
+
+ // Only parse these options in service mode
+ switch (optionName) {
+ case "--model", "-m" -> modelPath = Paths.get(nextArg);
+ case "--max-tokens", "-n" -> maxTokens = Integer.parseInt(nextArg);
+ case "--use-tornadovm" -> useTornadovm = Boolean.parseBoolean(nextArg);
+ }
+ }
+
+ require(modelPath != null, "Missing argument: --model is required");
+
+ // Do not use tornado by default
+ if (useTornadovm == null) {
+ useTornadovm = false;
+ }
+
+ // Create service-mode Options object
+ return new Options(
+ modelPath,
+ null, // prompt - not used in service
+ null, // systemPrompt - handled per request
+ null, // suffix - not used
+ false, // interactive - not used in service
+ 0.7f, // temperature - default, overridden per request
+ 0.9f, // topp - default, overridden per request
+ System.nanoTime(), // seed - default
+ maxTokens,
+ false, // stream - handled per request
+ false, // echo - not used in service
+ useTornadovm,
+ true
+ );
}
}
diff --git a/src/main/java/org/beehive/gpullama3/api/LLMApiApplication.java b/src/main/java/org/beehive/gpullama3/api/LLMApiApplication.java
new file mode 100644
index 00000000..2ff1d60d
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/api/LLMApiApplication.java
@@ -0,0 +1,16 @@
+package org.beehive.gpullama3.api;
+
+import org.springframework.boot.SpringApplication;
+import org.springframework.boot.autoconfigure.SpringBootApplication;
+
+@SpringBootApplication(scanBasePackages = "org.beehive.gpullama3")
+public class LLMApiApplication {
+
+ public static void main(String[] args) {
+ System.out.println("Starting TornadoVM LLM API Server...");
+ System.out.println("Command line arguments: " + String.join(" ", args));
+
+ // Validation happens in Options.parseServiceOptions(), invoked from LLMService - no duplication here
+ SpringApplication.run(LLMApiApplication.class, args);
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/api/config/ModelConfiguration.java b/src/main/java/org/beehive/gpullama3/api/config/ModelConfiguration.java
new file mode 100644
index 00000000..72e5a624
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/api/config/ModelConfiguration.java
@@ -0,0 +1,27 @@
+package org.beehive.gpullama3.api.config;
+
+import org.beehive.gpullama3.api.service.LLMService;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.Options;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+
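+// Registering Model and Options as beans lets other Spring components declare
+// them as plain constructor/field dependencies instead of reaching through
+// LLMService each time.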
+@Configuration
+public class ModelConfiguration {
+
+ /**
+ * Expose Model as a Spring bean using the initialized service
+ */
+ @Bean
+ public Model model(LLMService llmService) {
+ return llmService.getModel();
+ }
+
+ /**
+ * Expose Options as a Spring bean using the initialized service
+ */
+ @Bean
+ public Options options(LLMService llmService) {
+ return llmService.getOptions();
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/api/controller/ChatController.java b/src/main/java/org/beehive/gpullama3/api/controller/ChatController.java
new file mode 100644
index 00000000..ba40f274
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/api/controller/ChatController.java
@@ -0,0 +1,116 @@
+package org.beehive.gpullama3.api.controller;
+
+import org.beehive.gpullama3.api.service.LLMService;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.http.MediaType;
+import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.PostMapping;
+import org.springframework.web.bind.annotation.RequestBody;
+import org.springframework.web.bind.annotation.RestController;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+
+import java.util.Map;
+
+@RestController
+public class ChatController {
+
+ @Autowired
+ private LLMService llmService;
+
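+ /**
+  * Non-streaming chat. Example request (the same one the launcher prints on
+  * startup):
+  *   curl -X POST http://localhost:8080/chat \
+  *     -H "Content-Type: application/json" \
+  *     -d '{"message": "Hello!"}'
+  */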
+ @PostMapping("/chat")
+ public Map<String, String> chat(@RequestBody ChatRequest request) {
+ // Use request parameters with fallbacks to defaults
+ int maxTokens = request.getMaxTokens() != null ? request.getMaxTokens() : 150;
+ double temperature = request.getTemperature() != null ? request.getTemperature() : 0.7;
+ double topP = request.getTopP() != null ? request.getTopP() : 0.9;
+
+ logRequest("NON_STREAMING", request, maxTokens, temperature, topP);
+
+ if (request.getMessage() == null || request.getMessage().trim().isEmpty()) {
+ throw new IllegalArgumentException("Message cannot be empty");
+ }
+
+ String response = llmService.generateResponse(request.getMessage(), request.getSystemMessage(),
+ maxTokens, temperature, topP);
+
+ return Map.of("response", response);
+ }
+
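+ /**
+  * Streaming chat over Server-Sent Events. A sketch of a client call (curl's
+  * -N disables buffering so tokens appear as they arrive; the JSON fields
+  * match ChatRequest below):
+  *   curl -N -X POST http://localhost:8080/chat/stream \
+  *     -H "Content-Type: application/json" \
+  *     -d '{"message": "Hello!", "maxTokens": 64}'
+  */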
+ @PostMapping(value = "/chat/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
+ public SseEmitter streamChat(@RequestBody ChatRequest request) {
+ // Use request parameters with fallbacks to defaults
+ int maxTokens = request.getMaxTokens() != null ? request.getMaxTokens() : 150;
+ double temperature = request.getTemperature() != null ? request.getTemperature() : 0.7;
+ double topP = request.getTopP() != null ? request.getTopP() : 0.9;
+
+ logRequest("STREAMING", request, maxTokens, temperature, topP);
+
+ if (request.getMessage() == null || request.getMessage().trim().isEmpty()) {
+ throw new IllegalArgumentException("Message cannot be empty");
+ }
+
+ SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
+ llmService.generateStreamingResponse(
+ request.getMessage(),
+ request.getSystemMessage(),
+ emitter,
+ maxTokens,
+ temperature,
+ topP,
+ request.getSeed());
+
+ return emitter;
+ }
+
+
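+ /** Liveness probe, e.g.: curl http://localhost:8080/health */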
+ @GetMapping("/health")
+ public Map<String, String> health() {
+ return Map.of("status", "healthy", "timestamp", String.valueOf(System.currentTimeMillis()));
+ }
+
+ private void logRequest(String type, ChatRequest request, int maxTokens, double temperature, double topP) {
+ System.out.printf("REQUEST [%s] user='%s' system='%s' maxTokens=%d temp=%.2f topP=%.2f%n",
+ type,
+ truncate(request.getMessage(), 100),
+ request.getSystemMessage() != null ? truncate(request.getSystemMessage(), 40) : "none",
+ maxTokens,
+ temperature,
+ topP
+ );
+ }
+
+ private String truncate(String text, int maxLength) {
+ if (text == null) return "null";
+ return text.length() > maxLength ? text.substring(0, maxLength) + "..." : text;
+ }
+
+ // Simple request class for custom parameters
+ public static class ChatRequest {
+ private String message;
+ private String systemMessage;
+ private Integer maxTokens;
+ private Double temperature;
+ private Double topP;
+ private Long seed;
+
+ // Getters and Setters
+ public String getMessage() { return message; }
+ public void setMessage(String message) { this.message = message; }
+
+ public String getSystemMessage() { return systemMessage; }
+ public void setSystemMessage(String systemMessage) { this.systemMessage = systemMessage; }
+
+ public Integer getMaxTokens() { return maxTokens; }
+ public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; }
+
+ public Double getTemperature() { return temperature; }
+ public void setTemperature(Double temperature) { this.temperature = temperature; }
+
+ public Double getTopP() { return topP; }
+ public void setTopP(Double topP) { this.topP = topP; }
+
+ public Long getSeed() { return seed; }
+ public void setSeed(Long seed) { this.seed = seed; }
+ }
+
+}
diff --git a/src/main/java/org/beehive/gpullama3/api/model/CompletionRequest.java b/src/main/java/org/beehive/gpullama3/api/model/CompletionRequest.java
new file mode 100644
index 00000000..56a304e4
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/api/model/CompletionRequest.java
@@ -0,0 +1,59 @@
+package org.beehive.gpullama3.api.model;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import java.util.List;
+
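+/**
+ * OpenAI-style completion request. The wire format uses the snake_case names
+ * declared via @JsonProperty, e.g.:
+ *   {"prompt": "Why is the sky blue?", "max_tokens": 100,
+ *    "temperature": 0.7, "top_p": 0.9, "stream": false}
+ */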
+public class CompletionRequest {
+ private String model = "gpullama3";
+ private String prompt;
+
+ @JsonProperty("max_tokens")
+ private Integer maxTokens = 100;
+
+ private Double temperature = 0.7;
+
+ @JsonProperty("top_p")
+ private Double topP = 0.9;
+
+ @JsonProperty("stop")
+ private List<String> stopSequences;
+
+ private Boolean stream = false;
+
+ // Constructors
+ public CompletionRequest() {}
+
+ // Getters and Setters
+ public String getModel() { return model; }
+ public void setModel(String model) { this.model = model; }
+
+ public String getPrompt() { return prompt; }
+ public void setPrompt(String prompt) { this.prompt = prompt; }
+
+ public Integer getMaxTokens() { return maxTokens; }
+ public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; }
+
+ public Double getTemperature() { return temperature; }
+ public void setTemperature(Double temperature) { this.temperature = temperature; }
+
+ public Double getTopP() { return topP; }
+ public void setTopP(Double topP) { this.topP = topP; }
+
+ public List<String> getStopSequences() { return stopSequences; }
+ public void setStopSequences(List<String> stopSequences) { this.stopSequences = stopSequences; }
+
+ public Boolean getStream() { return stream; }
+ public void setStream(Boolean stream) { this.stream = stream; }
+
+ @Override
+ public String toString() {
+ return "CompletionRequest{" +
+ "model='" + model + '\'' +
+ ", prompt='" + (prompt != null ? prompt.substring(0, Math.min(50, prompt.length())) + "..." : null) + '\'' +
+ ", maxTokens=" + maxTokens +
+ ", temperature=" + temperature +
+ ", topP=" + topP +
+ ", stream=" + stream +
+ '}';
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/api/model/CompletionResponse.java b/src/main/java/org/beehive/gpullama3/api/model/CompletionResponse.java
new file mode 100644
index 00000000..b3cecf3a
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/api/model/CompletionResponse.java
@@ -0,0 +1,93 @@
+package org.beehive.gpullama3.api.model;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import java.util.List;
+
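+/**
+ * OpenAI-style completion response. Serialised roughly as:
+ *   {"id": "cmpl-...", "object": "text_completion", "created": 1700000000,
+ *    "model": "...", "choices": [{"text": "...", "index": 0, "finish_reason": "stop"}],
+ *    "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}}
+ */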
+public class CompletionResponse {
+ private String id;
+ private String object = "text_completion";
+ private Long created;
+ private String model;
+ private List<Choice> choices;
+ private Usage usage;
+
+ public static class Choice {
+ private String text;
+ private Integer index;
+
+ @JsonProperty("finish_reason")
+ private String finishReason;
+
+ public Choice() {}
+
+ public Choice(String text, Integer index, String finishReason) {
+ this.text = text;
+ this.index = index;
+ this.finishReason = finishReason;
+ }
+
+ // Getters and Setters
+ public String getText() { return text; }
+ public void setText(String text) { this.text = text; }
+
+ public Integer getIndex() { return index; }
+ public void setIndex(Integer index) { this.index = index; }
+
+ public String getFinishReason() { return finishReason; }
+ public void setFinishReason(String finishReason) { this.finishReason = finishReason; }
+ }
+
+ public static class Usage {
+ @JsonProperty("prompt_tokens")
+ private Integer promptTokens;
+
+ @JsonProperty("completion_tokens")
+ private Integer completionTokens;
+
+ @JsonProperty("total_tokens")
+ private Integer totalTokens;
+
+ public Usage() {}
+
+ public Usage(Integer promptTokens, Integer completionTokens) {
+ this.promptTokens = promptTokens;
+ this.completionTokens = completionTokens;
+ this.totalTokens = promptTokens + completionTokens;
+ }
+
+ // Getters and Setters
+ public Integer getPromptTokens() { return promptTokens; }
+ public void setPromptTokens(Integer promptTokens) { this.promptTokens = promptTokens; }
+
+ public Integer getCompletionTokens() { return completionTokens; }
+ public void setCompletionTokens(Integer completionTokens) { this.completionTokens = completionTokens; }
+
+ public Integer getTotalTokens() { return totalTokens; }
+ public void setTotalTokens(Integer totalTokens) { this.totalTokens = totalTokens; }
+ }
+
+ // Constructors
+ public CompletionResponse() {
+ this.id = "cmpl-" + System.currentTimeMillis();
+ this.created = System.currentTimeMillis() / 1000;
+ }
+
+ // Getters and Setters
+ public String getId() { return id; }
+ public void setId(String id) { this.id = id; }
+
+ public String getObject() { return object; }
+ public void setObject(String object) { this.object = object; }
+
+ public Long getCreated() { return created; }
+ public void setCreated(Long created) { this.created = created; }
+
+ public String getModel() { return model; }
+ public void setModel(String model) { this.model = model; }
+
+ public List<Choice> getChoices() { return choices; }
+ public void setChoices(List<Choice> choices) { this.choices = choices; }
+
+ public Usage getUsage() { return usage; }
+ public void setUsage(Usage usage) { this.usage = usage; }
+}
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/api/service/LLMService.java b/src/main/java/org/beehive/gpullama3/api/service/LLMService.java
new file mode 100644
index 00000000..9fca3e2a
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/api/service/LLMService.java
@@ -0,0 +1,243 @@
+package org.beehive.gpullama3.api.service;
+
+import jakarta.annotation.PostConstruct;
+import org.beehive.gpullama3.Options;
+import org.beehive.gpullama3.auxiliary.LastRunMetrics;
+import org.beehive.gpullama3.inference.sampler.Sampler;
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.format.ChatFormat;
+import org.beehive.gpullama3.model.loader.ModelLoader;
+import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
+import org.springframework.boot.ApplicationArguments;
+import org.springframework.stereotype.Service;
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.IntConsumer;
+
+import static org.beehive.gpullama3.inference.sampler.Sampler.selectSampler;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadModel;
+
+@Service
+public class LLMService {
+
+ private final ApplicationArguments args;
+
+ private Options options;
+ private Model model;
+
+ public LLMService(ApplicationArguments args) {
+ this.args = args;
+ }
+
+ @PostConstruct
+ public void init() {
+ try {
+ System.out.println("Initializing LLM service...");
+
+ // Step 1: Parse service options
+ System.out.println("Step 1: Parsing service options...");
+ options = Options.parseServiceOptions(args.getSourceArgs());
+ System.out.println("Model path: " + options.modelPath());
+ System.out.println("Context length: " + options.maxTokens());
+
+ // Step 2: Load model weights
+ System.out.println("\nStep 2: Loading model...");
+ System.out.println("Loading model from: " + options.modelPath());
+ model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true, options.useTornadovm());
+ System.out.println("ā Model loaded successfully");
+ System.out.println(" Model type: " + model.getClass().getSimpleName());
+ System.out.println(" Vocabulary size: " + model.configuration().vocabularySize());
+ System.out.println(" Context length: " + model.configuration().contextLength());
+
+ System.out.println("\nā Model service initialization completed successfully!");
+ System.out.println("=== Ready to serve requests ===\n");
+
+ } catch (Exception e) {
+ System.err.println("ā Failed to initialize model service: " + e.getMessage());
+ e.printStackTrace();
+ throw new RuntimeException("Model initialization failed", e);
+ }
+ }
+
+ /**
+ * Generate response with default parameters.
+ */
+ public String generateResponse(String message, String systemMessage) {
+ return generateResponse(message, systemMessage, 150, 0.7, 0.9);
+ }
+
+ public String generateResponse(String message, String systemMessage, int maxTokens, double temperature, double topP) {
+ return generateResponse(message, systemMessage, maxTokens, temperature, topP, null);
+ }
+
+ public String generateResponse(String message, String systemMessage, int maxTokens, double temperature, double topP, Long seed) {
+ try {
+ // Create sampler and state like runInstructOnce
+ long actualSeed = seed != null ? seed : System.currentTimeMillis();
+ Sampler sampler = selectSampler(model.configuration().vocabularySize(), (float) temperature, (float) topP, actualSeed);
+ State state = model.createNewState();
+
+ // Use model's ChatFormat
+ ChatFormat chatFormat = model.chatFormat();
+ List<Integer> promptTokens = new ArrayList<>();
+
+ // Add begin of text if needed
+ if (model.shouldAddBeginOfText()) {
+ promptTokens.add(chatFormat.getBeginOfText());
+ }
+
+ // Add system message properly formatted
+ if (model.shouldAddSystemPrompt() && systemMessage != null && !systemMessage.trim().isEmpty()) {
+ promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, systemMessage)));
+ }
+
+ // Add user message properly formatted
+ promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, message)));
+ promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
+
+ // Handle reasoning tokens if needed (for Deepseek-R1-Distill-Qwen)
+ if (model.shouldIncludeReasoning()) {
+ List<Integer> thinkStartTokens = model.tokenizer().encode("<think>\n", model.tokenizer().getSpecialTokens().keySet());
+ promptTokens.addAll(thinkStartTokens);
+ }
+
+ // Use proper stop tokens from chat format
+ Set<Integer> stopTokens = chatFormat.getStopTokens();
+
+ long startTime = System.currentTimeMillis();
+
+ // Non-streaming requests currently always take the CPU path; the GPU path is only wired up in generateStreamingResponse
+ List<Integer> generatedTokens = model.generateTokens(
+ state, 0, promptTokens, stopTokens, maxTokens, sampler, false, token -> {}
+ );
+
+ // Remove stop tokens if present
+ if (!generatedTokens.isEmpty() && stopTokens.contains(generatedTokens.getLast())) {
+ generatedTokens.removeLast();
+ }
+
+ long duration = System.currentTimeMillis() - startTime;
+ double tokensPerSecond = generatedTokens.size() * 1000.0 / duration;
+ System.out.printf("COMPLETED tokens=%d duration=%dms rate=%.1f tok/s%n",
+ generatedTokens.size(), duration, tokensPerSecond);
+
+ String responseText = model.tokenizer().decode(generatedTokens);
+
+ // Add reasoning prefix for non-streaming if needed
+ if (model.shouldIncludeReasoning()) {
+ responseText = "\n" + responseText;
+ }
+
+ return responseText;
+
+ } catch (Exception e) {
+ System.err.println("FAILED " + e.getMessage());
+ throw new RuntimeException("Failed to generate response", e);
+ }
+ }
+
+ public void generateStreamingResponse(String message, String systemMessage, SseEmitter emitter,
+ int maxTokens, double temperature, double topP, Long seed) {
+ CompletableFuture.runAsync(() -> {
+ try {
+ long actualSeed = seed != null ? seed : System.currentTimeMillis();
+ Sampler sampler = selectSampler(model.configuration().vocabularySize(), (float) temperature, (float) topP, actualSeed);
+ State state = model.createNewState();
+
+ // Use proper chat format like in runInstructOnce
+ ChatFormat chatFormat = model.chatFormat();
+ List<Integer> promptTokens = new ArrayList<>();
+
+ if (model.shouldAddBeginOfText()) {
+ promptTokens.add(chatFormat.getBeginOfText());
+ }
+
+ if (model.shouldAddSystemPrompt() && systemMessage != null && !systemMessage.trim().isEmpty()) {
+ promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, systemMessage)));
+ }
+
+ promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, message)));
+ promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
+
+ // Include reasoning for Deepseek-R1-Distill-Qwen
+ if (model.shouldIncludeReasoning()) {
+ List<Integer> thinkStartTokens = model.tokenizer().encode("<think>\n", model.tokenizer().getSpecialTokens().keySet());
+ promptTokens.addAll(thinkStartTokens);
+ // We are streaming, so emit the think-start marker immediately
+ emitter.send(SseEmitter.event().data("<think>\n"));
+ }
+
+ Set<Integer> stopTokens = chatFormat.getStopTokens();
+
+ final int[] tokenCount = {0};
+ IntConsumer tokenConsumer = token -> {
+ try {
+ // Only emit tokens the tokenizer marks as displayable (skips special tokens)
+ if (model.tokenizer().shouldDisplayToken(token)) {
+ String tokenText = model.tokenizer().decode(List.of(token));
+ emitter.send(SseEmitter.event().data(tokenText));
+ tokenCount[0]++;
+ }
+ } catch (Exception e) {
+ emitter.completeWithError(e);
+ }
+ };
+
+ // Initialize TornadoVM plan once per request if GPU path is enabled
+ TornadoVMMasterPlan tornadoVMPlan = null;
+ if (options.useTornadovm()) {
+ tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
+ }
+
+ // Select execution path
+ if (options.useTornadovm()) {
+ // GPU path
+ model.generateTokensGPU(state, 0, promptTokens, stopTokens, maxTokens, sampler, false, tokenConsumer, tornadoVMPlan);
+ } else {
+ // CPU path
+ model.generateTokens(state, 0, promptTokens, stopTokens, maxTokens, sampler, false, tokenConsumer);
+ }
+
+ LastRunMetrics metrics = LastRunMetrics.getMetrics();
+ double seconds = metrics.totalSeconds();
+ int tokens = metrics.totalTokens();
+ double tokensPerSecond = tokens / seconds;
+ System.out.printf("COMPLETED - [ achieved tok/s: %.2f. Tokens: %d, seconds: %.2f ]\n",
+ tokensPerSecond, tokens, seconds);
+
+ // Send metrics as named event before [DONE]
+ String metricsString = LastRunMetrics.getMetricsString();
+ if (!metricsString.isEmpty()) {
+ emitter.send(SseEmitter.event().name("metrics").data(metricsString));
+ }
+
+ emitter.send(SseEmitter.event().data("[DONE]"));
+ emitter.complete();
+
+ } catch (Exception e) {
+ System.err.println("FAILED " + e.getMessage());
+ emitter.completeWithError(e);
+ }
+ });
+ }
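+
+ // SSE contract for /chat/stream consumers: unnamed data events carry token
+ // text, a named "metrics" event is sent as a trailer, and a final "[DONE]"
+ // data event marks the end of the stream.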
+
+ // Getters for other services to access the initialized components
+ public Options getOptions() {
+ if (options == null) {
+ throw new IllegalStateException("Model service not initialized yet");
+ }
+ return options;
+ }
+
+ public Model getModel() {
+ if (model == null) {
+ throw new IllegalStateException("Model service not initialized yet");
+ }
+ return model;
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/auxiliary/LastRunMetrics.java b/src/main/java/org/beehive/gpullama3/auxiliary/LastRunMetrics.java
index 0d411801..6dfb0d58 100644
--- a/src/main/java/org/beehive/gpullama3/auxiliary/LastRunMetrics.java
+++ b/src/main/java/org/beehive/gpullama3/auxiliary/LastRunMetrics.java
@@ -21,6 +21,15 @@ public static void setMetrics(int tokens, double seconds) {
latestMetrics = new LastRunMetrics(tokens, seconds);
}
+ public static LastRunMetrics getMetrics() {
+ return latestMetrics;
+ }
+
+ public static String getMetricsString() {
+ if (latestMetrics == null) {
+ return ""; // no run recorded yet; callers already treat empty as "skip"
+ }
+ double tokensPerSecond = latestMetrics.totalTokens() / latestMetrics.totalSeconds();
+ return String.format("\n\nachieved tok/s: %.2f. Tokens: %d, seconds: %.2f\n", tokensPerSecond, latestMetrics.totalTokens(), latestMetrics.totalSeconds());
+ }
+
/**
* Prints the metrics from the latest run to stderr
*/